├── LICENSE
├── README.md
├── assets
│   └── overview.png
├── configs
│   ├── default.yaml
│   └── gefu
│       ├── dtu
│       │   ├── scan1.yaml
│       │   ├── scan103.yaml
│       │   ├── scan114.yaml
│       │   ├── scan21.yaml
│       │   └── scan8.yaml
│       ├── dtu_pretrain.yaml
│       ├── llff
│       │   ├── fern.yaml
│       │   ├── flower.yaml
│       │   ├── fortress.yaml
│       │   ├── horns.yaml
│       │   ├── leaves.yaml
│       │   ├── orchids.yaml
│       │   ├── room.yaml
│       │   └── trex.yaml
│       ├── llff_eval.yaml
│       ├── nerf
│       │   ├── chair.yaml
│       │   ├── drums.yaml
│       │   ├── ficus.yaml
│       │   ├── hotdog.yaml
│       │   ├── lego.yaml
│       │   ├── materials.yaml
│       │   ├── mic.yaml
│       │   └── ship.yaml
│       └── nerf_eval.yaml
├── data
│   └── mvsnerf
│       ├── dtu_train_all.txt
│       ├── dtu_val_all.txt
│       └── pairs.th
├── lib
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── config.py
│   │   └── yacs.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── collate_batch.py
│   │   ├── dtu
│   │   │   └── gefu.py
│   │   ├── gefu_utils.py
│   │   ├── llff
│   │   │   └── gefu.py
│   │   ├── make_dataset.py
│   │   ├── nerf
│   │   │   └── gefu.py
│   │   ├── samplers.py
│   │   └── video_utils.py
│   ├── evaluators
│   │   ├── __init__.py
│   │   ├── gefu.py
│   │   └── make_evaluator.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── gefu
│   │   │   ├── cost_reg_net.py
│   │   │   ├── feature_net.py
│   │   │   ├── nerf.py
│   │   │   ├── network.py
│   │   │   ├── res_unet.py
│   │   │   └── utils.py
│   │   └── make_network.py
│   ├── train
│   │   ├── __init__.py
│   │   ├── losses
│   │   │   ├── gefu.py
│   │   │   ├── nerf.py
│   │   │   ├── ssim_loss.py
│   │   │   └── vgg_perceptual_loss.py
│   │   ├── optimizer.py
│   │   ├── recorder.py
│   │   ├── scheduler.py
│   │   └── trainers
│   │       ├── __init__.py
│   │       ├── make_trainer.py
│   │       └── trainer.py
│   └── utils
│       ├── base_utils.py
│       ├── data_config.py
│       ├── data_utils.py
│       ├── img_utils.py
│       ├── mask_utils.py
│       ├── net_utils.py
│       ├── optimizer
│       │   ├── lr_scheduler.py
│       │   └── radam.py
│       ├── ply_utils.py
│       ├── rend_utils.py
│       └── vis_utils.py
├── requirements.txt
├── run.py
└── train_net.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Liu Tianqi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Geometry-aware Reconstruction and Fusion-refined Rendering for Generalizable Neural Radiance Fields
2 |
3 | PyTorch implementation of the paper "Geometry-aware Reconstruction and Fusion-refined Rendering for Generalizable Neural Radiance Fields", CVPR 2024.
4 |
5 | > [Geometry-aware Reconstruction and Fusion-refined Rendering for Generalizable Neural Radiance Fields](https://arxiv.org/abs/2404.17528)
6 | > Tianqi Liu, Xinyi Ye, Min Shi, Zihao Huang, Zhiyu Pan, Zhan Peng, Zhiguo Cao* \
7 | > CVPR 2024
8 | > [project page](https://gefucvpr24.github.io/) | [paper](https://arxiv.org/abs/2404.17528) | [poster](https://cvpr.thecvf.com/virtual/2024/poster/29889) | [model](https://drive.google.com/drive/folders/1pCCOLUj2fNAbp0ZXj7_tEPfF1vZvMbh-?usp=drive_link)
9 |
10 |
11 | ## Introduction
12 | Generalizable NeRF aims to synthesize novel views for unseen scenes. Common practices involve constructing variance-based cost volumes for geometry reconstruction and encoding 3D descriptors for decoding novel views. However, existing methods show limited generalization ability in challenging conditions due to inaccurate geometry, sub-optimal descriptors, and decoding strategies. We address these issues point by point. First, we find the variance-based cost volume exhibits failure patterns as the features of pixels corresponding to the same point can be inconsistent across different views due to occlusions or reflections. We introduce an Adaptive Cost Aggregation (ACA) approach to amplify the contribution of consistent pixel pairs and suppress inconsistent ones. Unlike previous methods that solely fuse 2D features into descriptors, our approach introduces a Spatial-View Aggregator (SVA) to incorporate 3D context into descriptors through spatial and inter-view interaction. When decoding the descriptors, we observe the two existing decoding strategies excel in different areas, which are complementary. A Consistency-Aware Fusion (CAF) strategy is proposed to leverage the advantages of both. We incorporate the above ACA, SVA, and CAF into a coarse-to-fine framework, termed Geometry-aware Reconstruction and Fusion-refined Rendering (GeFu). GeFu attains state-of-the-art performance across multiple datasets.
13 |
14 |
15 |
16 |
17 |
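To make the ACA idea above concrete, below is a minimal PyTorch sketch of the difference between a plain variance-based cost and an adaptively weighted one. This is an illustration only, not the repository's implementation (the real one lives under `lib/networks/gefu/`); the `scorer` network and all tensor shapes are assumptions.

```
import torch
import torch.nn as nn

def variance_cost(feats):
    # feats: (V, C, D, H, W) source-view features warped onto D depth planes.
    # The classic variance cost weights every view equally, so occluded or
    # reflective views corrupt the volume.
    return feats.var(dim=0)

def adaptive_cost(feats, scorer):
    # ACA-style idea: score each view by how far its warped feature deviates
    # from the cross-view mean, then form a weighted variance that amplifies
    # consistent views and suppresses inconsistent ones.
    mean = feats.mean(dim=0, keepdim=True)    # (1, C, D, H, W)
    diff = feats - mean                       # per-view deviation from consensus
    w = torch.softmax(scorer(diff), dim=0)    # (V, 1, D, H, W), sums to 1 over views
    return (w * diff.pow(2)).sum(dim=0)       # weighted variance, (C, D, H, W)

V, C, D, H, W = 3, 8, 16, 32, 40
feats = torch.randn(V, C, D, H, W)
scorer = nn.Conv3d(C, 1, kernel_size=1)       # hypothetical per-view scorer; V acts as the batch dim
print(variance_cost(feats).shape, adaptive_cost(feats, scorer).shape)
```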
18 | ## Installation
19 |
20 | ### Clone this repository:
21 |
22 | ```
23 | git clone https://github.com/TQTQliu/GeFu.git
24 | cd GeFu
25 | ```
26 |
27 | ### Set up the python environment
28 |
29 | ```
30 | conda create -n gefu python=3.8
31 | conda activate gefu
32 | pip install -r requirements.txt
33 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
34 | ```
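
Optionally, sanity-check that the CUDA build of PyTorch is visible (a quick smoke test, not part of the original setup steps):

```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```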
35 |
36 |
37 | ## Datasets
38 |
39 |
40 | #### 1. DTU
41 |
42 | **Training data**. Download [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) and [Depth raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip). Unzip and organize them as:
43 | ```
44 | mvs_training
45 | ├── dtu
46 | ├── Cameras
47 | ├── Depths
48 | ├── Depths_raw
49 | └── Rectified
50 | ```
51 |
52 | #### 2. NeRF Synthetic (Blender) and Real Forward-facing (LLFF)
53 |
54 | Download the [NeRF Synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and [Real Forward-facing](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) datasets and unzip them.
55 |
56 |
57 | ## Usage
58 | ### Train generalizable model
59 |
60 | To train a generalizable model from scratch on DTU, specify ``data_root`` in ``configs/gefu/dtu_pretrain.yaml`` first and then run:
61 | ```
62 | python train_net.py --cfg_file configs/gefu/dtu_pretrain.yaml
63 | ```
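
For reference, the entries being edited look like this in `configs/gefu/dtu_pretrain.yaml` (the path below is a placeholder for wherever you unzipped the data):

```
train_dataset:
  data_root: '/path/to/mvs_training/dtu/'
test_dataset:
  data_root: '/path/to/mvs_training/dtu/'
```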
64 |
65 | Our code also supports multi-GPU training; the released pretrained model was trained with 4 GPUs:
66 | ```
67 | python -m torch.distributed.launch --nproc_per_node=4 train_net.py --cfg_file configs/gefu/dtu_pretrain.yaml distributed True gpus 0,1,2,3
68 | ```
69 |
70 |
71 |
72 | ### Per-scene optimization
73 | Here we take scan1 from the DTU dataset as an example:
74 | ```
75 | cd ./trained_model/gefu
76 | mkdir dtu_ft_scan1
77 | cp dtu_pretrain/latest.pth dtu_ft_scan1
78 | cd ../..
79 | python train_net.py --cfg_file configs/gefu/dtu/scan1.yaml
80 | ```
81 |
82 | We provide the fine-tuned models for each scene [here](https://drive.google.com/drive/folders/11X_YI4BmYoRG1Q8AYnOnvPQmqQufqSMo?usp=drive_link).
83 |
84 |
85 |
86 | ### Evaluation
87 |
88 | #### Evaluate the pretrained model on DTU
89 |
90 | Download the [pretrained model](https://drive.google.com/drive/folders/1pCCOLUj2fNAbp0ZXj7_tEPfF1vZvMbh-?usp=drive_link) and place it at `trained_model/gefu/dtu_pretrain/latest.pth`.
91 |
92 | Use the following command to evaluate the pretrained model on DTU:
93 | ```
94 | python run.py --type evaluate --cfg_file configs/gefu/dtu_pretrain.yaml gefu.eval_depth True
95 | ```
96 | The rendered images will be saved in `result/gefu/dtu_pretrain`. Append `save_video True` to the end of the command to also save rendered videos.
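For example, the following simply combines the options documented above, using the same trailing key-value override syntax as the multi-GPU command:

```
python run.py --type evaluate --cfg_file configs/gefu/dtu_pretrain.yaml gefu.eval_depth True save_video True
```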
97 |
98 | #### Evaluate the pretrained model on Real Forward-facing
99 |
100 | ```
101 | python run.py --type evaluate --cfg_file configs/gefu/llff_eval.yaml
102 | ```
103 |
104 | #### Evaluate the pretrained model on NeRF Synthetic
105 | ```
106 | python run.py --type evaluate --cfg_file configs/gefu/nerf_eval.yaml
107 | ```
108 |
109 |
110 |
111 |
112 | ## Citation
113 | If you find our work useful for your research, please cite our paper.
114 |
115 | ```
116 | @InProceedings{Liu_2024_CVPR,
117 | author = {Liu, Tianqi and Ye, Xinyi and Shi, Min and Huang, Zihao and Pan, Zhiyu and Peng, Zhan and Cao, Zhiguo},
118 | title = {Geometry-aware Reconstruction and Fusion-refined Rendering for Generalizable Neural Radiance Fields},
119 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
120 | month = {June},
121 | year = {2024},
122 | pages = {7654-7663}
123 | }
124 | ```
125 |
126 | ## Relevant Works
127 |
128 | - [**PixelNeRF: Neural Radiance Fields from One or Few Images**](https://alexyu.net/pixelnerf/), CVPR 2021
129 |
130 | - [**IBRNet: Learning Multi-View Image-Based Rendering**](https://ibrnet.github.io/), CVPR 2021
131 |
132 | - [**MVSNeRF: Fast Generalizable Radiance Field Reconstruction from Multi-View Stereo**](https://apchenstu.github.io/mvsnerf/), ICCV 2021
133 |
134 | - [**Neural Rays for Occlusion-aware Image-based Rendering**](https://liuyuan-pal.github.io/NeuRay/), CVPR 2022
135 |
136 | - [**ENeRF: Efficient Neural Radiance Fields for Interactive Free-viewpoint Video**](https://zju3dv.github.io/enerf/), SIGGRAPH Asia 2022
137 |
138 | - [**Is Attention All NeRF Needs?**](https://vita-group.github.io/GNT/), ICLR 2023
139 |
140 | - [**Explicit Correspondence Matching for Generalizable Neural Radiance Fields**](https://donydchen.github.io/matchnerf/), arXiv 2023
141 |
142 |
143 | ## Acknowledgement
144 |
145 | The project is mainly based on [ENeRF](https://github.com/zju3dv/ENeRF?tab=readme-ov-file). Many thanks for their excellent contributions! When using our code, please also pay attention to the license of ENeRF.
146 |
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TQTQliu/GeFu/86443625c6ae69a531b81a2e672d3922d82f0e70/assets/overview.png
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TQTQliu/GeFu/86443625c6ae69a531b81a2e672d3922d82f0e70/configs/default.yaml
--------------------------------------------------------------------------------
/configs/gefu/dtu/scan1.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/dtu_pretrain.yaml
2 | exp_name: dtu_ft_scan1
3 | gefu:
4 | test_input_views: 4
5 | train_dataset:
6 | scene: scan1
7 | test_dataset:
8 | scene: scan1
9 | train:
10 | epoch: 222 # pretrained epoch + 6
11 | save_ep: 1
12 | eval_ep: 1
13 |
--------------------------------------------------------------------------------
/configs/gefu/dtu/scan103.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/dtu_pretrain.yaml
2 | exp_name: dtu_ft_scan103
3 | gefu:
4 | test_input_views: 4
5 | train_dataset:
6 | scene: scan103
7 | test_dataset:
8 | scene: scan103
9 | train:
10 | epoch: 222 # pretrained epoch + 6
11 | lr: 5e-5
12 | save_ep: 1
13 | eval_ep: 1
14 |
--------------------------------------------------------------------------------
/configs/gefu/dtu/scan114.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/dtu_pretrain.yaml
2 | exp_name: dtu_ft_scan114
3 | gefu:
4 | test_input_views: 4
5 | train_dataset:
6 | scene: scan114
7 | test_dataset:
8 | scene: scan114
9 | train:
10 | epoch: 222 # pretrained epoch + 6
11 | lr: 5e-5
12 | save_ep: 1
13 | eval_ep: 1
14 |
--------------------------------------------------------------------------------
/configs/gefu/dtu/scan21.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/dtu_pretrain.yaml
2 | exp_name: dtu_ft_scan21
3 | gefu:
4 | test_input_views: 4
5 | train_dataset:
6 | scene: scan21
7 | test_dataset:
8 | scene: scan21
9 | train:
10 | epoch: 222 # pretrained epoch + 6
11 | lr: 5e-6
12 | save_ep: 1
13 | eval_ep: 1
14 |
--------------------------------------------------------------------------------
/configs/gefu/dtu/scan8.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/dtu_pretrain.yaml
2 | exp_name: dtu_ft_scan8
3 | gefu:
4 | test_input_views: 4
5 | train_dataset:
6 | scene: scan8
7 | test_dataset:
8 | scene: scan8
9 | train:
10 | epoch: 222 # pretrained epoch + 6
11 | save_ep: 1
12 | eval_ep: 1
13 |
14 |
--------------------------------------------------------------------------------
/configs/gefu/dtu_pretrain.yaml:
--------------------------------------------------------------------------------
1 | task: gefu
2 | gpus: [0]
3 | exp_name: 'dtu_pretrain'
4 |
5 |
6 | # module
7 | train_dataset_module: lib.datasets.dtu.gefu
8 | test_dataset_module: lib.datasets.dtu.gefu
9 | network_module: lib.networks.gefu.network
10 | loss_module: lib.train.losses.gefu
11 | evaluator_module: lib.evaluators.gefu
12 |
13 | save_result: True
14 | eval_lpips: True
15 | save_video: False
16 | save_ply: False
17 |
18 |
19 | # task config
20 | gefu:
21 | train_input_views: [2, 3, 4]
22 | train_input_views_prob: [0.1, 0.8, 0.1]
23 | test_input_views: 3
24 | viewdir_agg: True
25 | chunk_size: 1000000
26 | white_bkgd: False
27 | eval_depth: False
28 | eval_center: False # only for llff evaluation (same as MVSNeRF: https://github.com/apchenstu/mvsnerf/blob/1fdf6487389d0872dade614b3cea61f7b099406e/renderer.ipynb)
29 | reweighting: False
30 | cas_config:
31 | num: 2
32 | depth_inv: [True, False]
33 | volume_scale: [0.125, 0.5]
34 | volume_planes: [64, 8]
35 | im_feat_scale: [0.25, 0.5]
36 | im_ibr_scale: [0.25, 1.]
37 | render_scale: [0.25, 1.0]
38 | render_im_feat_level: [0, 2]
39 | nerf_model_feat_ch: [32, 8]
40 | render_if: [True, True]
41 | num_samples: [8, 2]
42 | num_rays: [4096, 32768]
43 | num_patchs: [0, 0]
44 | train_img: [True, True]
45 | patch_size: [-1, -1]
46 | loss_weight: [0.5, 1.]
47 |
48 |
49 |
50 | train_dataset:
51 | data_root: 'mvs_training/dtu/'
52 | ann_file: 'data/mvsnerf/dtu_train_all.txt'
53 | split: 'train'
54 | batch_size: 2
55 | input_ratio: 1.
56 |
57 | test_dataset:
58 | data_root: 'mvs_training/dtu/'
59 | ann_file: 'data/mvsnerf/dtu_val_all.txt'
60 | split: 'test'
61 | batch_size: 1
62 | input_ratio: 1.
63 |
64 | train:
65 | batch_size: 1
66 | lr: 5e-4
67 | weight_decay: 0.
68 | epoch: 300
69 | scheduler:
70 | type: 'exponential'
71 | gamma: 0.5
72 | decay_epochs: 50
73 | batch_sampler: 'gefu'
74 | collator: 'gefu'
75 | sampler_meta:
76 | input_views_num: [2, 3, 4]
77 | input_views_prob: [0.1, 0.8, 0.1]
78 | num_workers: 4
79 |
80 | test:
81 | batch_size: 1
82 | collator: 'gefu'
83 | batch_sampler: 'gefu'
84 | sampler_meta:
85 | input_views_num: [3]
86 | input_views_prob: [1.]
87 |
88 | ep_iter: 1000
89 | save_ep: 1
90 | eval_ep: 1
91 | save_latest_ep: 1
92 | log_interval: 1
93 |
--------------------------------------------------------------------------------
/configs/gefu/llff/fern.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/llff_eval.yaml
2 | exp_name: llff_ft_fern
3 |
4 | gefu:
5 | test_input_views: 4
6 | train_input_views: [3, 4]
7 | train_input_views_prob: [0.4, 0.6]
8 | cas_config:
9 | render_if: [True, True]
10 | train_dataset:
11 | scene: fern
12 | test_dataset:
13 | scene: fern
14 | train:
15 | epoch: 222 # pretrained epoch + 6
16 | lr: 5e-5
17 | sampler_meta:
18 | input_views_num: [3, 4]
19 | input_views_prob: [0.4, 0.6]
20 | save_ep: 1
21 | eval_ep: 1
22 |
--------------------------------------------------------------------------------
/configs/gefu/llff/flower.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/llff_eval.yaml
2 | exp_name: llff_ft_flower
3 |
4 | gefu:
5 | test_input_views: 4
6 | train_input_views: [3, 4]
7 | train_input_views_prob: [0.4, 0.6]
8 | cas_config:
9 | render_if: [True, True]
10 | train_dataset:
11 | scene: flower
12 | test_dataset:
13 | scene: flower
14 | train:
15 | epoch: 222 # pretrained epoch + 6
16 | lr: 5e-5
17 | sampler_meta:
18 | input_views_num: [3, 4]
19 | input_views_prob: [0.4, 0.6]
20 | save_ep: 1
21 | eval_ep: 1
22 |
--------------------------------------------------------------------------------
/configs/gefu/llff/fortress.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/llff_eval.yaml
2 | exp_name: llff_ft_fortress
3 |
4 | gefu:
5 | test_input_views: 4
6 | train_input_views: [3, 4]
7 | train_input_views_prob: [0.4, 0.6]
8 | cas_config:
9 | render_if: [True, True]
10 | train_dataset:
11 | scene: fortress
12 | test_dataset:
13 | scene: fortress
14 | train:
15 | epoch: 222 # pretrained epoch + 6
16 | lr: 5e-6
17 | sampler_meta:
18 | input_views_num: [3, 4]
19 | input_views_prob: [0.4, 0.6]
20 | save_ep: 1
21 | eval_ep: 1
22 |
--------------------------------------------------------------------------------
/configs/gefu/llff/horns.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/llff_eval.yaml
2 | exp_name: llff_ft_horns
3 |
4 | gefu:
5 | test_input_views: 4
6 | train_input_views: [3, 4]
7 | train_input_views_prob: [0.4, 0.6]
8 | cas_config:
9 | render_if: [True, True]
10 | train_dataset:
11 | scene: horns
12 | test_dataset:
13 | scene: horns
14 | train:
15 | epoch: 222 # pretrained epoch + 6
16 | lr: 5e-5
17 | sampler_meta:
18 | input_views_num: [3, 4]
19 | input_views_prob: [0.4, 0.6]
20 | save_ep: 1
21 | eval_ep: 1
22 |
--------------------------------------------------------------------------------
/configs/gefu/llff/leaves.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/llff_eval.yaml
2 | exp_name: llff_ft_leaves
3 |
4 | gefu:
5 | test_input_views: 4
6 | train_input_views: [3, 4]
7 | train_input_views_prob: [0.4, 0.6]
8 | cas_config:
9 | render_if: [True, True]
10 | train_dataset:
11 | scene: leaves
12 | test_dataset:
13 | scene: leaves
14 | train:
15 | epoch: 222 # pretrained epoch + 6
16 | lr: 5e-5
17 | sampler_meta:
18 | input_views_num: [3, 4]
19 | input_views_prob: [0.4, 0.6]
20 | save_ep: 1
21 | eval_ep: 1
22 |
--------------------------------------------------------------------------------
/configs/gefu/llff/orchids.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/llff_eval.yaml
2 | exp_name: llff_ft_orchids
3 |
4 | gefu:
5 | test_input_views: 4
6 | train_input_views: [3, 4]
7 | train_input_views_prob: [0.4, 0.6]
8 | cas_config:
9 | render_if: [True, True]
10 | train_dataset:
11 | scene: orchids
12 | test_dataset:
13 | scene: orchids
14 | train:
15 | epoch: 222 # pretrained epoch + 6
16 | lr: 5e-5
17 | sampler_meta:
18 | input_views_num: [3, 4]
19 | input_views_prob: [0.4, 0.6]
20 | save_ep: 1
21 | eval_ep: 1
22 |
--------------------------------------------------------------------------------
/configs/gefu/llff/room.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/llff_eval.yaml
2 | exp_name: llff_ft_room
3 |
4 | gefu:
5 | test_input_views: 4
6 | train_input_views: [3, 4]
7 | train_input_views_prob: [0.4, 0.6]
8 | cas_config:
9 | render_if: [True, True]
10 | train_dataset:
11 | scene: room
12 | test_dataset:
13 | scene: room
14 | train:
15 | epoch: 222 # pretrained epoch + 6
16 | lr: 1e-3
17 | sampler_meta:
18 | input_views_num: [3, 4]
19 | input_views_prob: [0.4, 0.6]
20 | save_ep: 1
21 | eval_ep: 1
22 |
--------------------------------------------------------------------------------
/configs/gefu/llff/trex.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/llff_eval.yaml
2 | exp_name: llff_ft_trex
3 |
4 | gefu:
5 | test_input_views: 4
6 | train_input_views: [3, 4]
7 | train_input_views_prob: [0.4, 0.6]
8 | cas_config:
9 | render_if: [True, True]
10 | train_dataset:
11 | scene: trex
12 | test_dataset:
13 | scene: trex
14 | train:
15 | epoch: 222 # pretrained epoch + 6
16 | lr: 5e-4
17 | sampler_meta:
18 | input_views_num: [3, 4]
19 | input_views_prob: [0.4, 0.6]
20 | save_ep: 1
21 | eval_ep: 1
22 |
--------------------------------------------------------------------------------
/configs/gefu/llff_eval.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/dtu_pretrain.yaml
2 |
3 | train_dataset_module: lib.datasets.llff.gefu
4 | test_dataset_module: lib.datasets.llff.gefu
5 |
6 | gefu:
7 | eval_center: True
8 | reweighting: True
9 | cas_config:
10 | render_if: [True, True]
11 | volume_planes: [32, 8]
12 |
13 | train_dataset:
14 | data_root: 'nerf_llff_data'
15 | split: 'train'
16 | # input_h_w: [640, 960] # OOM for RTX 3090
17 | input_h_w: [512, 640]
18 | batch_size: 1
19 | input_ratio: 1.
20 |
21 | test_dataset:
22 | data_root: 'nerf_llff_data'
23 | split: 'test'
24 | batch_size: 1
25 | input_h_w: [640, 960]
26 | input_ratio: 1.
27 |
--------------------------------------------------------------------------------
/configs/gefu/nerf/chair.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/nerf_eval.yaml
2 | exp_name: nerf_ft_chair
3 | gefu:
4 | test_input_views: 4
5 | train_input_views: [3, 4]
6 | train_input_views_prob: [0.4, 0.6]
7 | cas_config:
8 | render_if: [True, True]
9 | train_dataset:
10 | scene: chair
11 | test_dataset:
12 | scene: chair
13 | train:
14 | epoch: 222 # pretrained epoch + 6
15 | lr: 3e-4
16 | save_ep: 1
17 | eval_ep: 1
18 |
--------------------------------------------------------------------------------
/configs/gefu/nerf/drums.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/nerf_eval.yaml
2 | exp_name: nerf_ft_drums
3 | gefu:
4 | test_input_views: 4
5 | train_input_views: [3, 4]
6 | train_input_views_prob: [0.4, 0.6]
7 | cas_config:
8 | render_if: [True, True]
9 | train_dataset:
10 | scene: drums
11 | test_dataset:
12 | scene: drums
13 | train:
14 | epoch: 222 # pretrained epoch + 6
15 | lr: 5e-5
16 | save_ep: 1
17 | eval_ep: 1
18 |
--------------------------------------------------------------------------------
/configs/gefu/nerf/ficus.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/nerf_eval.yaml
2 | exp_name: nerf_ft_ficus
3 | gefu:
4 | test_input_views: 4
5 | train_input_views: [3, 4]
6 | train_input_views_prob: [0.4, 0.6]
7 | cas_config:
8 | render_if: [True, True]
9 | train_dataset:
10 | scene: ficus
11 | test_dataset:
12 | scene: ficus
13 | train:
14 | epoch: 225 # pretrained epoch + 6
15 | lr: 8e-5
16 | save_ep: 1
17 | eval_ep: 1
18 |
--------------------------------------------------------------------------------
/configs/gefu/nerf/hotdog.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/nerf_eval.yaml
2 | exp_name: nerf_ft_hotdog
3 | gefu:
4 | test_input_views: 4
5 | train_input_views: [3, 4]
6 | train_input_views_prob: [0.4, 0.6]
7 | cas_config:
8 | render_if: [True, True]
9 | train_dataset:
10 | scene: hotdog
11 | test_dataset:
12 | scene: hotdog
13 | train:
14 | epoch: 222 # pretrained epoch + 6
15 | lr: 5e-4
16 | save_ep: 1
17 | eval_ep: 1
18 |
--------------------------------------------------------------------------------
/configs/gefu/nerf/lego.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/nerf_eval.yaml
2 | exp_name: nerf_ft_lego
3 | gefu:
4 | test_input_views: 4
5 | train_input_views: [3, 4]
6 | train_input_views_prob: [0.4, 0.6]
7 | cas_config:
8 | render_if: [True, True]
9 | train_dataset:
10 | scene: lego
11 | test_dataset:
12 | scene: lego
13 | train:
14 | epoch: 222 # pretrained epoch + 6
15 | lr: 5e-4
16 | save_ep: 1
17 | eval_ep: 1
18 |
--------------------------------------------------------------------------------
/configs/gefu/nerf/materials.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/nerf_eval.yaml
2 | exp_name: nerf_ft_materials
3 | gefu:
4 | test_input_views: 4
5 | train_input_views: [3, 4]
6 | train_input_views_prob: [0.4, 0.6]
7 | cas_config:
8 | render_if: [True, True]
9 | train_dataset:
10 | scene: materials
11 | test_dataset:
12 | scene: materials
13 | train:
14 | epoch: 222 # pretrained epoch + 6
15 | lr: 6e-4
16 | save_ep: 1
17 | eval_ep: 1
18 |
--------------------------------------------------------------------------------
/configs/gefu/nerf/mic.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/nerf_eval.yaml
2 | exp_name: nerf_ft_mic
3 | gefu:
4 | test_input_views: 4
5 | train_input_views: [3, 4]
6 | train_input_views_prob: [0.4, 0.6]
7 | cas_config:
8 | render_if: [True, True]
9 | train_dataset:
10 | scene: mic
11 | test_dataset:
12 | scene: mic
13 | train:
14 | epoch: 222 # pretrained epoch + 6
15 | lr: 4e-4
16 | save_ep: 1
17 | eval_ep: 1
18 |
--------------------------------------------------------------------------------
/configs/gefu/nerf/ship.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/nerf_eval.yaml
2 | exp_name: nerf_ft_ship
3 | gefu:
4 | test_input_views: 4
5 | train_input_views: [3, 4]
6 | train_input_views_prob: [0.4, 0.6]
7 | cas_config:
8 | render_if: [True, True]
9 | train_dataset:
10 | scene: ship
11 | test_dataset:
12 | scene: ship
13 | train:
14 | epoch: 222 # pretrained epoch + 6
15 | lr: 8e-4
16 | save_ep: 1
17 | eval_ep: 1
18 |
--------------------------------------------------------------------------------
/configs/gefu/nerf_eval.yaml:
--------------------------------------------------------------------------------
1 | parent_cfg: configs/gefu/dtu_pretrain.yaml
2 |
3 | train_dataset_module: lib.datasets.nerf.gefu
4 | test_dataset_module: lib.datasets.nerf.gefu
5 |
6 | gefu:
7 | reweighting: True
8 | cas_config:
9 | render_if: [True, True]
10 |
11 | train_dataset:
12 | data_root: 'nerf_synthetic'
13 | split: 'train'
14 | # input_h_w: [800, 800] # OOM for RTX 3090
15 | input_h_w: [512, 640]
16 | batch_size: 1
17 | input_ratio: 1.
18 |
19 | test_dataset:
20 | data_root: 'nerf_synthetic'
21 | split: 'test'
22 | input_h_w: [800, 800]
23 | batch_size: 1
24 | input_ratio: 1.
25 |
26 |
27 |
--------------------------------------------------------------------------------
/data/mvsnerf/dtu_train_all.txt:
--------------------------------------------------------------------------------
1 | scan3
2 | scan4
3 | scan5
4 | scan6
5 | scan9
6 | scan10
7 | scan11
8 | scan12
9 | scan13
10 | scan14
11 | scan15
12 | scan16
13 | scan17
14 | scan18
15 | scan19
16 | scan20
17 | scan22
18 | scan23
19 | scan24
20 | scan28
21 | scan32
22 | scan33
23 | scan35
24 | scan36
25 | scan37
26 | scan42
27 | scan43
28 | scan44
29 | scan46
30 | scan47
31 | scan48
32 | scan49
33 | scan50
34 | scan52
35 | scan53
36 | scan59
37 | scan60
38 | scan61
39 | scan62
40 | scan64
41 | scan65
42 | scan66
43 | scan67
44 | scan68
45 | scan69
46 | scan70
47 | scan71
48 | scan72
49 | scan74
50 | scan75
51 | scan76
52 | scan77
53 | scan84
54 | scan85
55 | scan86
56 | scan87
57 | scan88
58 | scan89
59 | scan90
60 | scan91
61 | scan92
62 | scan93
63 | scan94
64 | scan95
65 | scan96
66 | scan97
67 | scan98
68 | scan99
69 | scan100
70 | scan101
71 | scan102
72 | scan104
73 | scan105
74 | scan106
75 | scan107
76 | scan108
77 | scan109
78 | scan118
79 | scan119
80 | scan120
81 | scan121
82 | scan122
83 | scan123
84 | scan124
85 | scan125
86 | scan126
87 | scan127
88 | scan128
--------------------------------------------------------------------------------
/data/mvsnerf/dtu_val_all.txt:
--------------------------------------------------------------------------------
1 | scan1
2 | scan8
3 | scan21
4 | scan30
5 | scan31
6 | scan34
7 | scan38
8 | scan40
9 | scan41
10 | scan45
11 | scan55
12 | scan63
13 | scan82
14 | scan103
15 | scan110
16 | scan114
17 |
--------------------------------------------------------------------------------
/data/mvsnerf/pairs.th:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TQTQliu/GeFu/86443625c6ae69a531b81a2e672d3922d82f0e70/data/mvsnerf/pairs.th
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TQTQliu/GeFu/86443625c6ae69a531b81a2e672d3922d82f0e70/lib/__init__.py
--------------------------------------------------------------------------------
/lib/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import cfg, args
2 |
--------------------------------------------------------------------------------
/lib/config/config.py:
--------------------------------------------------------------------------------
1 | from .yacs import CfgNode as CN
2 | import argparse
3 | import os
4 | import numpy as np
5 | from . import yacs
6 |
7 |
8 | cfg = CN()
9 |
10 | os.environ['workspace'] = "."
11 | cfg.workspace = os.environ['workspace']
12 | print('Workspace: ', cfg.workspace)
13 |
14 |
15 | cfg.level = 32.
16 | cfg.resolution = 256
17 |
18 | cfg.vis_encoder = ''
19 | cfg.feat_vis_len = 8
20 |
21 | cfg.cache_data = False
22 | cfg.sample_keypoints_epoch = -1
23 |
24 | cfg.write_video = False
25 | cfg.interested_mask = False
26 | cfg.render_path = False
27 | cfg.render_emb = 0
28 | cfg.render_ixt = 0
29 | cfg.code_id = -1
30 | cfg.time_weight = 0.
31 | cfg.render_static = True
32 | cfg.pretrain_path = ''
33 | cfg.scene = 'test'
34 | cfg.last_view = False
35 | cfg.exp_hard = False
36 | cfg.pos_encoding_t = False
37 | cfg.render_time = False
38 | cfg.render_time_skip = [0, -1, 1]
39 | cfg.start_time = [2009, 8, 1]
40 | cfg.end_time = [2013, 12, 1]
41 | cfg.discrete_3views = False
42 | cfg.fps = 24
43 | cfg.dcat = False
44 | cfg.min_y = -100000000.
45 | cfg.time_discrete = -1
46 | cfg.render_day = False
47 | cfg.render_date = [2013, 1, 1]
48 | cfg.rand_t = -1.
49 | cfg.semantic_mask = False
50 | cfg.product_combine = False
51 | cfg.unisample = False
52 | cfg.render_emb_2 = -1
53 | cfg.render_num = 30
54 | cfg.render_ext = 0
55 | cfg.time_geo = False
56 | cfg.reg_beta = False
57 | cfg.fix_beta = False
58 | cfg.hard_lap = False
59 | cfg.render_octree = False
60 | cfg.render_mask = False
61 | cfg.environment_map = False
62 |
63 | cfg.save_result = False
64 | cfg.clear_result = False
65 | cfg.save_tag = 'default'
66 | # module
67 | cfg.train_dataset_module = 'lib.datasets.dtu.neus'
68 | cfg.test_dataset_module = 'lib.datasets.dtu.neus'
69 | cfg.val_dataset_module = 'lib.datasets.dtu.neus'
70 | cfg.network_module = 'lib.networks.neus.neus'
71 | cfg.loss_module = 'lib.train.losses.neus'
72 | cfg.evaluator_module = 'lib.evaluators.neus'
73 |
74 | # experiment name
75 | cfg.exp_name = 'gitbranch_hello'
76 | cfg.exp_name_tag = ''
77 | cfg.pretrain = ''
78 |
79 | # network
80 | cfg.distributed = False
81 |
82 | # task
83 | cfg.task = 'hello'
84 |
85 | # gpus
86 | cfg.gpus = list(range(4))
87 | # if load the pretrained network
88 | cfg.resume = True
89 |
90 | # epoch
91 | cfg.ep_iter = -1
92 | cfg.save_ep = 1
93 | cfg.save_latest_ep = 1
94 | cfg.eval_ep = 1
95 | cfg.log_interval = 20
96 |
97 |
98 | cfg.task_arg = CN()
99 | cfg.task_arg.sample_more_on_mask = -1.
100 | cfg.task_arg.sample_on_mask = False
101 |
102 | # -----------------------------------------------------------------------------
103 | # train
104 | # -----------------------------------------------------------------------------
105 | cfg.train = CN()
106 | cfg.train.epoch = 10000
107 | cfg.train.num_workers = 8
108 | cfg.train.collator = 'default'
109 | cfg.train.batch_sampler = 'default'
110 | cfg.train.sampler_meta = CN({})
111 | cfg.train.shuffle = True
112 | cfg.train.eps = 1e-8
113 |
114 | # use adam as default
115 | cfg.train.optim = 'adam'
116 | cfg.train.lr = 5e-4
117 | cfg.train.weight_decay = 0.
118 | cfg.train.scheduler = CN({'type': 'multi_step', 'milestones': [80, 120, 200, 240], 'gamma': 0.5})
119 | cfg.train.batch_size = 4
120 |
121 | # test
122 | cfg.test = CN()
123 | cfg.test.batch_size = 1
124 | cfg.test.collator = 'default'
125 | cfg.test.epoch = -1
126 | cfg.test.batch_sampler = 'default'
127 | cfg.test.sampler_meta = CN({})
128 |
129 | # trained model
130 | cfg.trained_model_dir = os.path.join(os.environ['workspace'], 'trained_model')
131 | cfg.clean_tag = 'debug'
132 |
133 | # recorder
134 | cfg.record_dir = os.path.join(os.environ['workspace'], 'record')
135 |
136 | # result
137 | cfg.result_dir = os.path.join(os.environ['workspace'], 'result')
138 |
139 | # evaluation
140 | cfg.skip_eval = False
141 |
142 | cfg.fix_random = False
143 |
144 | def parse_cfg(cfg, args):
145 | if len(cfg.task) == 0:
146 | raise ValueError('task must be specified')
147 |
148 | # assign the gpus
149 | if -1 not in cfg.gpus:
150 | os.environ['CUDA_VISIBLE_DEVICES'] = ', '.join([str(gpu) for gpu in cfg.gpus])
151 |
152 | if 'bbox' in cfg:
153 | bbox = np.array(cfg.bbox).reshape((2, 3))
154 | center, half_size = np.mean(bbox, axis=0), (bbox[1]-bbox[0]).max().item() / 2.
155 | bbox = np.stack([center-half_size, center+half_size])
156 | cfg.bbox = bbox.reshape(6).tolist()
157 |
158 | if len(cfg.exp_name_tag) != 0:
159 | cfg.exp_name += ('_' + cfg.exp_name_tag)
160 | cfg.exp_name = cfg.exp_name.replace('gitbranch', os.popen('git describe --all').readline().strip()[6:])
161 | cfg.exp_name = cfg.exp_name.replace('gitcommit', os.popen('git describe --tags --always').readline().strip())
162 | print('EXP NAME: ', cfg.exp_name)
163 | cfg.trained_model_dir = os.path.join(cfg.trained_model_dir, cfg.task, cfg.exp_name)
164 | cfg.record_dir = os.path.join(cfg.record_dir, cfg.task, cfg.exp_name)
165 | cfg.result_dir = os.path.join(cfg.result_dir, cfg.task, cfg.exp_name, cfg.save_tag)
166 | cfg.local_rank = args.local_rank
167 | modules = [key for key in cfg if '_module' in key]
168 | for module in modules:
169 | cfg[module.replace('_module', '_path')] = cfg[module].replace('.', '/') + '.py'
170 |
171 | def make_cfg(args):
172 | def merge_cfg(cfg_file, cfg):
173 | with open(cfg_file, 'r') as f:
174 | current_cfg = yacs.load_cfg(f)
175 | if 'parent_cfg' in current_cfg.keys():
176 | cfg = merge_cfg(current_cfg.parent_cfg, cfg)
177 | cfg.merge_from_other_cfg(current_cfg)
178 | else:
179 | cfg.merge_from_other_cfg(current_cfg)
180 | print(cfg_file)
181 | return cfg
182 | cfg_ = merge_cfg(args.cfg_file, cfg)
183 | try:
184 | index = args.opts.index('other_opts')
185 | cfg_.merge_from_list(args.opts[:index])
186 | except ValueError:  # 'other_opts' not present in the override list
187 | cfg_.merge_from_list(args.opts)
188 | parse_cfg(cfg_, args)
189 | return cfg_
190 |
191 |
192 | parser = argparse.ArgumentParser()
193 | parser.add_argument("--cfg_file", default="configs/gefu/dtu_pretrain.yaml", type=str)
194 | parser.add_argument('--test', action='store_true', dest='test', default=False)
195 | parser.add_argument("--type", type=str, default="")
196 | parser.add_argument('--det', type=str, default='')
197 | parser.add_argument('--local_rank', type=int, default=0)
198 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER)
199 | args = parser.parse_args()
200 | if len(args.type) > 0:
201 | cfg.task = "run"
202 | cfg = make_cfg(args)
203 |
204 |
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_dataset import make_data_loader
2 |
--------------------------------------------------------------------------------
/lib/datasets/collate_batch.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataloader import default_collate
2 | import torch
3 | import numpy as np
4 | from lib.config import cfg
5 |
6 | _collators = {}
7 |
8 | def make_collator(cfg, is_train):
9 | collator = cfg.train.collator if is_train else cfg.test.collator
10 | if collator in _collators:
11 | return _collators[collator]
12 | else:
13 | return default_collate
14 |
--------------------------------------------------------------------------------
/lib/datasets/dtu/gefu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from lib.datasets import gefu_utils
4 | from lib.config import cfg
5 | import imageio
6 | import cv2
7 | import random
8 | from lib.config import cfg
9 | from lib.utils import data_utils
10 | import torch
11 | from lib.datasets.video_utils import *
12 |
13 |
14 | if cfg.fix_random:
15 | random.seed(0)
16 | np.random.seed(0)
17 |
18 | class Dataset:
19 | def __init__(self, **kwargs):
20 | super(Dataset, self).__init__()
21 | self.data_root = os.path.join(cfg.workspace, kwargs['data_root'])
22 | self.split = kwargs['split']
23 | if 'scene' in kwargs:
24 | self.scenes = [kwargs['scene']]
25 | else:
26 | self.scenes = []
27 | self.build_metas(kwargs['ann_file'])
28 | self.depth_ranges = [425., 905.]
29 |
30 | def build_metas(self, ann_file):
31 | scenes = [line.strip() for line in open(ann_file).readlines()]
32 | dtu_pairs = torch.load('data/mvsnerf/pairs.th')
33 | self.scene_infos = {}
34 | self.metas = []
35 | if len(self.scenes) != 0:
36 | scenes = self.scenes
37 |
38 | for scene in scenes:
39 | scene_info = {'ixts': [], 'exts': [], 'dpt_paths': [], 'img_paths': []}
40 | for i in range(49):
41 | cam_path = os.path.join(self.data_root, 'Cameras/train/{:08d}_cam.txt'.format(i))
42 | ixt, ext, _ = data_utils.read_cam_file(cam_path)
43 | ext[:3, 3] = ext[:3, 3]
44 | ixt[:2] = ixt[:2] * 4
45 | dpt_path = os.path.join(self.data_root, 'Depths_raw/{}/depth_map_{:04d}.pfm'.format(scene, i))
46 | img_path = os.path.join(self.data_root, 'Rectified/{}_train/rect_{:03d}_3_r5000.png'.format(scene, i+1))
47 | scene_info['ixts'].append(ixt.astype(np.float32))
48 | scene_info['exts'].append(ext.astype(np.float32))
49 | scene_info['dpt_paths'].append(dpt_path)
50 | scene_info['img_paths'].append(img_path)
51 |
52 | if self.split == 'train' and len(self.scenes) != 1:
53 | train_ids = np.arange(49).tolist()
54 | test_ids = np.arange(49).tolist()
55 | elif self.split == 'train' and len(self.scenes) == 1:
56 | train_ids = dtu_pairs['dtu_train']
57 | test_ids = dtu_pairs['dtu_train']
58 | else:
59 | train_ids = dtu_pairs['dtu_train']
60 | test_ids = dtu_pairs['dtu_val']
61 | scene_info.update({'train_ids': train_ids, 'test_ids': test_ids})
62 | self.scene_infos[scene] = scene_info
63 |
64 | cam_points = np.array([np.linalg.inv(scene_info['exts'][i])[:3, 3] for i in train_ids])
65 | for tar_view in test_ids:
66 | cam_point = np.linalg.inv(scene_info['exts'][tar_view])[:3, 3]
67 | distance = np.linalg.norm(cam_points - cam_point[None], axis=-1)
68 | argsorts = distance.argsort()
69 | argsorts = argsorts[1:] if tar_view in train_ids else argsorts  # drop the target view itself from its own neighbor list
70 | input_views_num = cfg.gefu.train_input_views[1] + 1 if self.split == 'train' else cfg.gefu.test_input_views
71 | src_views = [train_ids[i] for i in argsorts[:input_views_num]]
72 | self.metas += [(scene, tar_view, src_views)]
73 |
74 | def __getitem__(self, index_meta):
75 | index, input_views_num = index_meta
76 | scene, tar_view, src_views = self.metas[index]
77 | # src_views = [tar_view] + src_views # only used for depth evaluation under reference view
78 | if self.split == 'train':
79 | if random.random() < 0.1:
80 | src_views = src_views + [tar_view]
81 | src_views = random.sample(src_views[:input_views_num+1], input_views_num)
82 | scene_info = self.scene_infos[scene]
83 |
84 | tar_img = np.array(imageio.imread(scene_info['img_paths'][tar_view])) / 255.
85 | H, W = tar_img.shape[:2]
86 | tar_ext, tar_ixt = scene_info['exts'][tar_view], scene_info['ixts'][tar_view]
87 | if self.split != 'train': # only used for evaluation
88 | tar_dpt = data_utils.read_pfm(scene_info['dpt_paths'][tar_view])[0].astype(np.float32)
89 | tar_dpt = cv2.resize(tar_dpt, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_NEAREST)
90 | tar_dpt = tar_dpt[44:556, 80:720]  # center-crop the 600x800 half-resolution depth to 512x640
91 | tar_mask = (tar_dpt > 0.).astype(np.uint8)
92 | else:
93 | # tar_dpt = np.ones_like(tar_img)
94 | # tar_mask = np.ones_like(tar_img)
95 | tar_dpt = data_utils.read_pfm(scene_info['dpt_paths'][tar_view])[0].astype(np.float32)
96 | tar_dpt = cv2.resize(tar_dpt, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_NEAREST)
97 | tar_dpt = tar_dpt[44:556, 80:720]
98 | tar_mask = (tar_dpt > 0.).astype(np.uint8)
99 |
100 | src_inps, src_exts, src_ixts = self.read_src(scene_info, src_views)
101 |
102 | ret = {'src_inps': src_inps,
103 | 'src_exts': src_exts,
104 | 'src_ixts': src_ixts}
105 | ret.update({'tar_ext': tar_ext,
106 | 'tar_ixt': tar_ixt})
107 | if self.split != 'train':
108 | ret.update({'tar_img': tar_img,
109 | 'tar_dpt': tar_dpt,
110 | 'tar_mask': tar_mask})
111 | ret.update({'near_far': np.array(self.depth_ranges).astype(np.float32)})
112 | ret.update({'meta': {'scene': scene, 'tar_view': tar_view, 'frame_id': 0}})
113 |
114 | for i in range(cfg.gefu.cas_config.num):
115 | rays, rgb, msk = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split)
116 | s = cfg.gefu.cas_config.volume_scale[i]
117 | if self.split != 'train': # evaluation
118 | tar_dpt_i = cv2.resize(tar_dpt, None, fx=s, fy=s, interpolation=cv2.INTER_NEAREST)
119 | ret.update({f'tar_dpt_{i}': tar_dpt_i.astype(np.float32)})
120 | ret.update({f'rays_{i}': rays, f'rgb_{i}': rgb.astype(np.float32), f'msk_{i}': msk})
121 | ret['meta'].update({f'h_{i}': H, f'w_{i}': W})
122 |
123 | if cfg.save_video:
124 | rendering_video_meta = []
125 | render_path_mode = 'interpolate'
126 | poses_paths = self.get_video_rendering_path(ref_poses=src_exts, mode=render_path_mode, near_far=None, train_c2w_all=None, n_frames=60)
127 | for pose in poses_paths[0]:
128 | rendering_meta = {
129 | 'tar_ext': pose
130 | }
131 | for i in range(cfg.gefu.cas_config.num):
132 | tar_ext[:3] = pose
133 | rays, _, _ = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split)
134 | rendering_meta.update({f'rays_{i}': rays})
135 | rendering_video_meta.append(rendering_meta)
136 | ret['rendering_video_meta'] = rendering_video_meta
137 | return ret
138 |
139 |
140 | def get_video_rendering_path(self, ref_poses, mode, near_far, train_c2w_all, n_frames=60):
141 | poses_paths = []
142 | ref_poses = ref_poses[None]
143 | for batch_idx, cur_src_poses in enumerate(ref_poses):
144 | if mode == 'interpolate':
145 | # convert to c2ws
146 | pose_square = torch.eye(4).unsqueeze(0).repeat(cur_src_poses.shape[0], 1, 1)
147 | cur_src_poses = torch.from_numpy(cur_src_poses)
148 | pose_square[:, :3, :] = cur_src_poses[:,:3]
149 | cur_c2ws = pose_square.double().inverse()[:, :3, :].to(torch.float32).cpu().detach().numpy()
150 | cur_path = get_interpolate_render_path(cur_c2ws, n_frames)
151 | elif mode == 'spiral':
152 | cur_c2ws_all = train_c2w_all
153 | cur_near_far = near_far.tolist()
154 | rads_scale = 0.3
155 | cur_path = get_spiral_render_path(cur_c2ws_all, cur_near_far, rads_scale=rads_scale, N_views=n_frames)
156 | else:
157 | raise Exception(f'Unknown video rendering path mode {mode}')
158 |
159 | # convert back to extrinsics tensor
160 | cur_w2cs = torch.tensor(cur_path).inverse()[:, :3].to(torch.float32)
161 | poses_paths.append(cur_w2cs)
162 |
163 | poses_paths = torch.stack(poses_paths, dim=0)
164 | return poses_paths
165 |
166 | def read_src(self, scene_info, src_views):
167 | inps, exts, ixts = [], [], []
168 | for src_view in src_views:
169 | inps.append((np.array(imageio.imread(scene_info['img_paths'][src_view])) / 255.) * 2. - 1.)
170 | exts.append(scene_info['exts'][src_view])
171 | ixts.append(scene_info['ixts'][src_view])
172 | return np.stack(inps).transpose((0, 3, 1, 2)).astype(np.float32), np.stack(exts), np.stack(ixts)
173 |
174 | def __len__(self):
175 | return len(self.metas)
176 |
177 |
--------------------------------------------------------------------------------
/lib/datasets/gefu_utils.py:
--------------------------------------------------------------------------------
1 | from lib.config import cfg
2 | import cv2
3 | import numpy as np
4 |
5 | def sample_patch(num_patch, patch_size, H, W, msk_sample):
6 | half_patch_size = patch_size // 2
7 | if msk_sample.sum() > 0:
8 | num_fg_patch = num_patch
9 | non_zero = msk_sample.nonzero()
10 | permutation = np.random.permutation(msk_sample.sum())[:num_fg_patch].astype(np.int32)
11 | X_, Y_ = non_zero[1][permutation], non_zero[0][permutation]
12 | X_ = np.clip(X_, half_patch_size, W-half_patch_size)
13 | Y_ = np.clip(Y_, half_patch_size, H-half_patch_size)
14 | else:
15 | num_fg_patch = 0
16 | num_patch = num_patch - num_fg_patch
17 | X = np.random.randint(low=half_patch_size, high=W-half_patch_size, size=num_patch)
18 | Y = np.random.randint(low=half_patch_size, high=H-half_patch_size, size=num_patch)
19 | if num_fg_patch > 0:
20 | X = np.concatenate([X, X_]).astype(np.int32)
21 | Y = np.concatenate([Y, Y_]).astype(np.int32)
22 | grid = np.meshgrid(np.arange(patch_size)-half_patch_size, np.arange(patch_size)-half_patch_size)
23 | return np.concatenate([grid[0].reshape(-1) + x for x in X]), np.concatenate([grid[1].reshape(-1) + y for y in Y])
24 |
25 | def build_rays(tar_img, tar_ext, tar_ixt, tar_msk, level, split):
26 | scale = cfg.gefu.cas_config.render_scale[level]
27 | if scale != 1.:
28 | tar_img = cv2.resize(tar_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
29 | tar_msk = cv2.resize(tar_msk, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST)
30 | tar_ixt = tar_ixt.copy()
31 | tar_ixt[:2] *= scale
32 | H, W = tar_img.shape[:2]
33 | c2w = np.linalg.inv(tar_ext)
34 | if split == 'train' and not cfg.gefu.cas_config.train_img[level]:
35 | if cfg.gefu.sample_on_mask: # 313
36 | msk_sample = tar_msk
37 | num_fg_rays = int(min(cfg.gefu.cas_config.num_rays[level]*0.75, tar_msk.sum()*0.95))
38 | non_zero = msk_sample.nonzero()
39 | permutation = np.random.permutation(tar_msk.sum())[:num_fg_rays].astype(np.int32)
40 | X_, Y_ = non_zero[1][permutation], non_zero[0][permutation]
41 | else:
42 | num_fg_rays = 0
43 | msk_sample = np.zeros_like(tar_msk)
44 | num_rays = cfg.gefu.cas_config.num_rays[level] - num_fg_rays
45 | X = np.random.randint(low=0, high=W, size=num_rays)
46 | Y = np.random.randint(low=0, high=H, size=num_rays)
47 | if num_fg_rays > 0:
48 | X = np.concatenate([X, X_]).astype(np.int32)
49 | Y = np.concatenate([Y, Y_]).astype(np.int32)
50 | if cfg.gefu.cas_config.num_patchs[level] > 0:
51 | X_, Y_ = sample_patch(cfg.gefu.cas_config.num_patchs[level], cfg.gefu.cas_config.patch_size[level], H, W, msk_sample)
52 | X = np.concatenate([X, X_]).astype(np.int32)
53 | Y = np.concatenate([Y, Y_]).astype(np.int32)
54 | num_rays = len(X)
55 | rays_o = c2w[:3, 3][None].repeat(num_rays, 0)
56 | XYZ = np.concatenate((X[:, None], Y[:, None], np.ones_like(X[:, None])), axis=-1)
57 | XYZ = XYZ @ (np.linalg.inv(tar_ixt).T @ c2w[:3, :3].T)
58 | rays = np.concatenate((rays_o, XYZ, X[..., None], Y[..., None]), axis=-1)
59 | rgb = tar_img[Y, X]
60 | msk = tar_msk[Y, X]
61 | else:
62 | rays_o = c2w[:3, 3][None, None]
63 | X, Y = np.meshgrid(np.arange(W), np.arange(H))
64 | XYZ = np.concatenate((X[:, :, None], Y[:, :, None], np.ones_like(X[:, :, None])), axis=-1)
65 | XYZ = XYZ @ (np.linalg.inv(tar_ixt).T @ c2w[:3, :3].T)
66 | rays_o = rays_o.repeat(H, axis=0)
67 | rays_o = rays_o.repeat(W, axis=1)
68 | rays = np.concatenate((rays_o, XYZ, X[..., None], Y[..., None]), axis=-1)
69 | rgb = tar_img
70 | msk = tar_msk
71 | return rays.astype(np.float32).reshape(-1, 8), rgb.reshape(-1, 3), msk.reshape(-1)  # each ray packs origin(3) + direction(3) + pixel (x, y)
72 |
73 |
74 |
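For orientation, each row of the returned `rays` array packs eight floats — origin (3), un-normalized direction (3), and the pixel coordinates (x, y) — matching the `np.concatenate` calls above. A downstream consumer might split them like this (illustrative sketch, not code from this repo):

```
# rays: (N, 8) float32
rays_o, rays_d, xy = rays[:, :3], rays[:, 3:6], rays[:, 6:]  # origins, directions, pixel coords
```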
--------------------------------------------------------------------------------
/lib/datasets/llff/gefu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from lib.config import cfg
4 | import imageio
5 | import cv2
6 | import random
7 | from lib.config import cfg
8 | import torch
9 | from lib.datasets import gefu_utils
10 | from lib.datasets.video_utils import *
11 |
12 |
13 | class Dataset:
14 | def __init__(self, **kwargs):
15 | super(Dataset, self).__init__()
16 | self.data_root = os.path.join(cfg.workspace, kwargs['data_root'])
17 | self.split = kwargs['split']
18 | self.input_h_w = kwargs['input_h_w']
19 | if 'scene' in kwargs:
20 | self.scenes = [kwargs['scene']]
21 | else:
22 | self.scenes = []
23 | self.build_metas()
24 |
25 | def build_metas(self):
26 | if len(self.scenes) == 0:
27 | scenes = ['fern', 'flower', 'fortress', 'horns', 'leaves', 'orchids', 'room', 'trex']
28 | else:
29 | scenes = self.scenes
30 | self.scene_infos = {}
31 | self.metas = []
32 | self.c2ws_all = {}
33 | pairs = torch.load('data/mvsnerf/pairs.th')
34 | for scene in scenes:
35 |
36 | pose_bounds = np.load(os.path.join(self.data_root, scene, 'poses_bounds.npy')) # c2w, -u, r, -t
37 | poses = pose_bounds[:, :15].reshape((-1, 3, 5))
38 | c2ws = np.eye(4)[None].repeat(len(poses), 0)
39 | c2ws[:, :3, 0], c2ws[:, :3, 1], c2ws[:, :3, 2], c2ws[:, :3, 3] = poses[:, :3, 1], poses[:, :3, 0], -poses[:, :3, 2], poses[:, :3, 3]  # reorder LLFF's (down, right, back) axes to OpenCV-style (right, down, forward)
40 | ixts = np.eye(3)[None].repeat(len(poses), 0)
41 | ixts[:, 0, 0], ixts[:, 1, 1] = poses[:, 2, 4], poses[:, 2, 4]
42 | ixts[:, 0, 2], ixts[:, 1, 2] = poses[:, 1, 4]/2., poses[:, 0, 4]/2.
43 | ixts[:, :2] *= 0.25
44 |
45 | img_paths = sorted([item for item in os.listdir(os.path.join(self.data_root, scene, 'images_4')) if '.png' in item])
46 | depth_ranges = pose_bounds[:, -2:]
47 | scene_info = {'ixts': ixts.astype(np.float32), 'c2ws': c2ws.astype(np.float32), 'image_names': img_paths, 'depth_ranges': depth_ranges.astype(np.float32)}
48 | scene_info['scene_name'] = scene
49 | self.scene_infos[scene] = scene_info
50 |
51 | train_ids = pairs[f'{scene}_train']
52 | if self.split == 'train':
53 | render_ids = train_ids
54 | else:
55 | render_ids = pairs[f'{scene}_val']
56 |
57 | c2ws = np.stack(c2ws)
58 | self.c2ws_all[scene] = c2ws
59 | c2ws = c2ws[train_ids]
60 | for i in render_ids:
61 | c2w = scene_info['c2ws'][i]
62 | distance = np.linalg.norm((c2w[:3, 3][None] - c2ws[:, :3, 3]), axis=-1)
63 | argsorts = distance.argsort()
64 | argsorts = argsorts[1:] if i in train_ids else argsorts
65 | if self.split == 'train':
66 | src_views = [train_ids[i] for i in argsorts[:cfg.gefu.train_input_views[1]+1]]
67 | else:
68 | src_views = [train_ids[i] for i in argsorts[:cfg.gefu.test_input_views]]
69 | self.metas += [(scene, i, src_views)]
70 |
71 | def get_video_rendering_path(self, ref_poses, mode, near_far, train_c2w_all, n_frames=60):
72 | poses_paths = []
73 | ref_poses = ref_poses[None]
74 | for batch_idx, cur_src_poses in enumerate(ref_poses):
75 | if mode == 'interpolate':
76 | # convert to c2ws
77 | pose_square = torch.eye(4).unsqueeze(0).repeat(cur_src_poses.shape[0], 1, 1)
78 | cur_src_poses = torch.from_numpy(cur_src_poses)
79 | pose_square[:, :3, :] = cur_src_poses[:,:3]
80 | cur_c2ws = pose_square.double().inverse()[:, :3, :].to(torch.float32).cpu().detach().numpy()
81 | cur_path = get_interpolate_render_path(cur_c2ws, n_frames)
82 | elif mode == 'spiral':
83 | cur_c2ws_all = train_c2w_all
84 | cur_near_far = near_far.tolist()
85 | rads_scale = 1
86 | cur_path = get_spiral_render_path(cur_c2ws_all, cur_near_far, rads_scale=rads_scale, N_views=n_frames)
87 | else:
88 | raise Exception(f'Unknown video rendering path mode {mode}')
89 |
90 | # convert back to extrinsics tensor
91 | cur_w2cs = torch.tensor(cur_path).inverse()[:, :3].to(torch.float32)
92 | poses_paths.append(cur_w2cs)
93 |
94 | poses_paths = torch.stack(poses_paths, dim=0)
95 | return poses_paths
96 |
97 | def __getitem__(self, index_meta):
98 | index, input_views_num = index_meta
99 | scene, tar_view, src_views = self.metas[index]
100 | if self.split == 'train':
101 | if np.random.random() < 0.1:
102 | src_views = src_views + [tar_view]
103 | src_views = random.sample(src_views, input_views_num)
104 | scene_info = self.scene_infos[scene]
105 | tar_img, tar_mask, tar_ext, tar_ixt = self.read_tar(scene_info, tar_view)
106 | src_inps, src_exts, src_ixts = self.read_src(scene_info, src_views)
107 |
108 | ret = {'src_inps': src_inps.transpose(0, 3, 1, 2),
109 | 'src_exts': src_exts,
110 | 'src_ixts': src_ixts}
111 | ret.update({'tar_ext': tar_ext,
112 | 'tar_ixt': tar_ixt})
113 | if self.split != 'train':
114 | ret.update({'tar_img': tar_img,
115 | 'tar_mask': tar_mask})
116 |
117 | H, W = tar_img.shape[:2]
118 | depth_ranges = np.array(scene_info['depth_ranges'])
119 | near_far = np.array([depth_ranges[:, 0].min().item(), depth_ranges[:, 1].max().item()]).astype(np.float32)
120 | # near_far = scene_info['depth_ranges'][tar_view]
121 | ret.update({'near_far': np.array(near_far).astype(np.float32)})
122 | ret.update({'meta': {'scene': scene, 'tar_view': tar_view, 'frame_id': 0}})
123 |
124 | for i in range(cfg.gefu.cas_config.num):
125 | rays, rgb, msk = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split)
126 | ret.update({f'rays_{i}': rays, f'rgb_{i}': rgb.astype(np.float32), f'msk_{i}': msk})
127 | s = cfg.gefu.cas_config.volume_scale[i]
128 | ret['meta'].update({f'h_{i}': int(H*s), f'w_{i}': int(W*s)})
129 | if cfg.save_video:
130 | rendering_video_meta = []
131 | render_path_mode = 'spiral'
132 | train_c2w_all = self.c2ws_all[scene][src_views]
133 | poses_paths = self.get_video_rendering_path(src_exts, render_path_mode, near_far, train_c2w_all, 60)
134 | for pose in poses_paths[0]:
135 | rendering_meta = {
136 | 'tar_ext': pose
137 | }
138 | for i in range(cfg.gefu.cas_config.num):
139 | tar_ext[:3] = pose
140 | rays, _, _ = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split)
141 | rendering_meta.update({f'rays_{i}': rays})
142 | rendering_video_meta.append(rendering_meta)
143 | ret['rendering_video_meta'] = rendering_video_meta
144 | return ret
145 |
146 | def read_src(self, scene, src_views):
147 | src_ids = src_views
148 | ixts, exts, imgs = [], [], []
149 | for idx in src_ids:
150 | img, orig_size = self.read_image(scene, idx)
151 | imgs.append(((img/255.)*2-1).astype(np.float32))
152 | ixt, ext, _ = self.read_cam(scene, idx, orig_size)
153 | ixts.append(ixt)
154 | exts.append(ext)
155 | return np.stack(imgs), np.stack(exts), np.stack(ixts)
156 |
157 | def read_tar(self, scene, view_idx):
158 | img, orig_size = self.read_image(scene, view_idx)
159 | img = (img/255.).astype(np.float32)
160 | ixt, ext, _ = self.read_cam(scene, view_idx, orig_size)
161 | mask = np.ones_like(img[..., 0]).astype(np.uint8)
162 | return img, mask, ext, ixt
163 |
164 | def read_cam(self, scene, view_idx, orig_size):
165 | ext = scene['c2ws'][view_idx].astype(np.float32)
166 | ixt = scene['ixts'][view_idx].copy()
167 | ixt[0] *= self.input_h_w[1] / orig_size[0]
168 | ixt[1] *= self.input_h_w[0] / orig_size[1]
169 | return ixt, np.linalg.inv(ext), 1  # ext here is a stored c2w pose; invert it to return w2c extrinsics
170 |
171 | def read_image(self, scene, view_idx):
172 | image_path = os.path.join(self.data_root, scene['scene_name'], 'images_4', scene['image_names'][view_idx])
173 | img = (np.array(imageio.imread(image_path))).astype(np.float32)
174 | orig_size = img.shape[:2][::-1]
175 | img = cv2.resize(img, self.input_h_w[::-1], interpolation=cv2.INTER_AREA)
176 | return np.array(img), orig_size
177 |
178 | def __len__(self):
179 | return len(self.metas)
180 |
181 | def get_K_from_params(params):
182 | K = np.zeros((3, 3)).astype(np.float32)
183 | K[0][0], K[0][2], K[1][2] = params[:3]
184 | K[1][1] = K[0][0]
185 | K[2][2] = 1.
186 | return K
187 |
188 |
--------------------------------------------------------------------------------
/lib/datasets/make_dataset.py:
--------------------------------------------------------------------------------
1 | from . import samplers
2 | import torch
3 | import torch.utils.data
4 | import imp  # deprecated and removed in Python 3.12; see the importlib sketch after this file
5 | import os
6 | from .collate_batch import make_collator
7 | import numpy as np
8 | import time
9 | from lib.config.config import cfg
10 | from torch.utils.data import DataLoader, ConcatDataset
11 | import cv2
12 | cv2.setNumThreads(1)
13 |
14 |
15 | # torch.multiprocessing.set_sharing_strategy('file_system')
16 |
17 | def _dataset_factory(is_train, is_val):
18 | if is_val:
19 | module = cfg.val_dataset_module
20 | path = cfg.val_dataset_path
21 | elif is_train:
22 | module = cfg.train_dataset_module
23 | path = cfg.train_dataset_path
24 | else:
25 | module = cfg.test_dataset_module
26 | path = cfg.test_dataset_path
27 | dataset = imp.load_source(module, path).Dataset
28 | return dataset
29 |
30 |
31 | def make_dataset(cfg, is_train=True):
32 | if is_train:
33 | args = cfg.train_dataset
34 | module = cfg.train_dataset_module
35 | path = cfg.train_dataset_path
36 | else:
37 | args = cfg.test_dataset
38 | module = cfg.test_dataset_module
39 | path = cfg.test_dataset_path
40 | dataset = imp.load_source(module, path).Dataset
41 | dataset = dataset(**args)
42 | return dataset
43 |
44 |
45 | def make_data_sampler(dataset, shuffle, is_distributed):
46 | if is_distributed:
47 | return samplers.DistributedSampler(dataset, shuffle=shuffle)
48 | if shuffle:
49 | sampler = torch.utils.data.sampler.RandomSampler(dataset)
50 | else:
51 | sampler = torch.utils.data.sampler.SequentialSampler(dataset)
52 | return sampler
53 |
54 |
55 | def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter,
56 | is_train):
57 | if is_train:
58 | batch_sampler = cfg.train.batch_sampler
59 | sampler_meta = cfg.train.sampler_meta
60 | else:
61 | batch_sampler = cfg.test.batch_sampler
62 | sampler_meta = cfg.test.sampler_meta
63 | if batch_sampler == 'default':
64 | batch_sampler = torch.utils.data.sampler.BatchSampler(
65 | sampler, batch_size, drop_last)
66 | elif batch_sampler == 'image_size':
67 | batch_sampler = samplers.ImageSizeBatchSampler(sampler, batch_size,
68 | drop_last, sampler_meta)
69 | elif batch_sampler == 'gefu':
70 | batch_sampler = samplers.GefuBatchSampler(sampler, batch_size, drop_last, sampler_meta)
71 | if max_iter != -1:
72 | batch_sampler = samplers.IterationBasedBatchSampler(
73 | batch_sampler, max_iter)
74 | return batch_sampler
75 |
76 |
77 | def worker_init_fn(worker_id):
78 | np.random.seed(worker_id + (int(round(time.time() * 1000) % (2**16))))
79 |
80 |
81 | def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1):
82 | if is_train:
83 | batch_size = cfg.train.batch_size
84 | # shuffle = True
85 | shuffle = cfg.train.shuffle
86 | drop_last = False
87 | else:
88 | batch_size = cfg.test.batch_size
89 |         shuffle = is_distributed
90 | drop_last = False
91 |
92 | dataset = make_dataset(cfg, is_train)
93 | sampler = make_data_sampler(dataset, shuffle, is_distributed)
94 | batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size,
95 | drop_last, max_iter, is_train)
96 | num_workers = cfg.train.num_workers
97 | collator = make_collator(cfg, is_train)
98 | data_loader = DataLoader(dataset,
99 | batch_sampler=batch_sampler,
100 | num_workers=num_workers,
101 | collate_fn=collator,
102 | worker_init_fn=worker_init_fn)
103 |
104 | return data_loader
105 |
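`imp.load_source`, used by both factory functions above, is deprecated and was removed in Python 3.12. A hypothetical drop-in replacement built on `importlib.util` (the helper name here is our own, not part of this repo):

```python
import importlib.util

def load_source(module_name, path):
    """Load a module from an explicit file path (imp.load_source replacement)."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # executes the file, populating the module
    return module

# Mirrors make_dataset():
# dataset_cls = load_source(cfg.train_dataset_module, cfg.train_dataset_path).Dataset
```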
--------------------------------------------------------------------------------
/lib/datasets/nerf/gefu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from lib.config import cfg
4 | import imageio
5 | import cv2
6 | import random
7 | from lib.config import cfg
8 | import torch
9 | import json
10 | from lib.datasets import gefu_utils
11 | from lib.datasets.video_utils import *
12 |
13 | class Dataset:
14 | def __init__(self, **kwargs):
15 | super(Dataset, self).__init__()
16 | self.data_root = os.path.join(cfg.workspace, kwargs['data_root'])
17 | self.split = kwargs['split']
18 | self.input_h_w = kwargs['input_h_w']
19 | if 'scene' in kwargs:
20 | self.scenes = [kwargs['scene']]
21 | else:
22 | self.scenes = []
23 | self.build_metas()
24 |
25 | def build_metas(self):
26 | if len(self.scenes) == 0:
27 | scenes = ['chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship']
28 | else:
29 | scenes = self.scenes
30 | self.scene_infos = {}
31 | self.metas = []
32 | pairs = torch.load('data/mvsnerf/pairs.th')
33 | for scene in scenes:
34 | json_info = json.load(open(os.path.join(self.data_root, scene,'transforms_train.json')))
35 | b2c = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
36 | scene_info = {'ixts': [], 'exts': [], 'img_paths': []}
37 | for idx in range(len(json_info['frames'])):
38 | c2w = np.array(json_info['frames'][idx]['transform_matrix'])
39 | c2w = c2w @ b2c
40 | ext = np.linalg.inv(c2w)
41 | ixt = np.eye(3)
42 | ixt[0][2], ixt[1][2] = 400., 400.
43 | focal = .5 * 800 / np.tan(.5 * json_info['camera_angle_x'])
44 | ixt[0][0], ixt[1][1] = focal, focal
45 | scene_info['ixts'].append(ixt.astype(np.float32))
46 | scene_info['exts'].append(ext.astype(np.float32))
47 | img_path = os.path.join(self.data_root, scene, 'train/r_{}.png'.format(idx))
48 | scene_info['img_paths'].append(img_path)
49 | self.scene_infos[scene] = scene_info
50 | train_ids, render_ids = pairs[f'{scene}_train'], pairs[f'{scene}_val']
51 | if self.split == 'train':
52 | render_ids = train_ids
53 | c2ws = np.stack([np.linalg.inv(scene_info['exts'][idx]) for idx in train_ids])
54 | for idx in render_ids:
55 | c2w = np.linalg.inv(scene_info['exts'][idx])
56 | distance = np.linalg.norm((c2w[:3, 3][None] - c2ws[:, :3, 3]), axis=-1)
57 |
58 | argsorts = distance.argsort()
59 | argsorts = argsorts[1:] if idx in train_ids else argsorts
60 |
61 | input_views_num = cfg.gefu.train_input_views[1] + 1 if self.split == 'train' else cfg.gefu.test_input_views
62 | src_views = [train_ids[i] for i in argsorts[:input_views_num]]
63 | self.metas += [(scene, idx, src_views)]
64 |
65 | def get_video_rendering_path(self, ref_poses, mode, near_far, train_c2w_all, n_frames=60):
66 | poses_paths = []
67 | ref_poses = ref_poses[None]
68 | for batch_idx, cur_src_poses in enumerate(ref_poses):
69 | if mode == 'interpolate':
70 | # convert to c2ws
71 | pose_square = torch.eye(4).unsqueeze(0).repeat(cur_src_poses.shape[0], 1, 1)
72 | cur_src_poses = torch.from_numpy(cur_src_poses)
73 | pose_square[:, :3, :] = cur_src_poses[:,:3]
74 | cur_c2ws = pose_square.double().inverse()[:, :3, :].to(torch.float32).cpu().detach().numpy()
75 | cur_path = get_interpolate_render_path(cur_c2ws, n_frames)
76 | elif mode == 'spiral':
77 | cur_c2ws_all = train_c2w_all
78 | cur_near_far = near_far.tolist()
79 | rads_scale = 0.3
80 | cur_path = get_spiral_render_path(cur_c2ws_all, cur_near_far, rads_scale=rads_scale, N_views=n_frames)
81 | else:
82 | raise Exception(f'Unknown video rendering path mode {mode}')
83 |
84 | # convert back to extrinsics tensor
85 | cur_w2cs = torch.tensor(cur_path).inverse()[:, :3].to(torch.float32)
86 | poses_paths.append(cur_w2cs)
87 |
88 | poses_paths = torch.stack(poses_paths, dim=0)
89 | return poses_paths
90 |
91 |
92 | def __getitem__(self, index_meta):
93 | index, input_views_num = index_meta
94 | scene, tar_view, src_views = self.metas[index]
95 | if self.split == 'train':
96 | if np.random.random() < 0.1:
97 | src_views = src_views + [tar_view]
98 | src_views = random.sample(src_views, input_views_num)
99 | scene_info = self.scene_infos[scene]
100 | scene_info['scene_name'] = scene
101 | tar_img, tar_ext, tar_ixt = self.read_tar(scene_info, tar_view)
102 | src_inps, src_exts, src_ixts = self.read_src(scene_info, src_views)
103 |
104 | ret = {'src_inps': src_inps.transpose(0, 3, 1, 2),
105 | 'src_exts': src_exts,
106 | 'src_ixts': src_ixts}
107 | tar_mask = np.ones_like(tar_img[..., 0]).astype(np.uint8)
108 | H, W = tar_img.shape[:2]
109 | ret.update({'tar_ext': tar_ext,
110 | 'tar_ixt': tar_ixt})
111 | if self.split != 'train':
112 | ret.update({'tar_img': tar_img,
113 | 'tar_mask': tar_mask})
114 | near_far = np.array([2.5, 5.5]).astype(np.float32)
115 |         ret.update({'near_far': near_far})
116 | ret.update({'meta': {'scene': scene, 'tar_view': tar_view, 'frame_id': 0}})
117 |
118 | for i in range(cfg.gefu.cas_config.num):
119 | rays, rgb, msk = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split)
120 | ret.update({f'rays_{i}': rays, f'rgb_{i}': rgb.astype(np.float32), f'msk_{i}': msk})
121 | s = cfg.gefu.cas_config.volume_scale[i]
122 | ret['meta'].update({f'h_{i}': int(H*s), f'w_{i}': int(W*s)})
123 | if cfg.save_video:
124 | rendering_video_meta = []
125 | render_path_mode = 'interpolate'
126 | poses_paths = self.get_video_rendering_path(ref_poses=src_exts, mode=render_path_mode, near_far=None, train_c2w_all=None, n_frames=60)
127 | for pose in poses_paths[0]:
128 | rendering_meta = {
129 | 'tar_ext': pose
130 | }
131 | for i in range(cfg.gefu.cas_config.num):
132 | tar_ext[:3] = pose
133 | rays, _, _ = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split)
134 | rendering_meta.update({f'rays_{i}': rays})
135 | rendering_video_meta.append(rendering_meta)
136 | ret['rendering_video_meta'] = rendering_video_meta
137 | return ret
138 |
139 | def read_src(self, scene, src_views):
140 | src_ids = src_views
141 | ixts, exts, imgs = [], [], []
142 | for idx in src_ids:
143 | img, orig_size = self.read_image(scene, idx)
144 | imgs.append((img*2-1).astype(np.float32))
145 | ixt, ext = self.read_cam(scene, idx, orig_size)
146 | ixts.append(ixt)
147 | exts.append(ext)
148 | return np.stack(imgs), np.stack(exts), np.stack(ixts)
149 |
150 | def read_tar(self, scene, view_idx):
151 | img, orig_size = self.read_image(scene, view_idx)
152 | ixt, ext = self.read_cam(scene, view_idx, orig_size)
153 | return img, ext, ixt
154 |
155 | def read_cam(self, scene, view_idx, orig_size):
156 | ext = scene['exts'][view_idx]
157 | ixt = scene['ixts'][view_idx].copy()
158 | ixt[0] *= self.input_h_w[1] / orig_size[0]
159 | ixt[1] *= self.input_h_w[0] / orig_size[1]
160 | return ixt, ext
161 |
162 | def read_image(self, scene, view_idx):
163 | img_path = scene['img_paths'][view_idx]
164 | img = (np.array(imageio.imread(img_path)) / 255.).astype(np.float32)
165 | img = (img[..., :3] * img[..., -1:] + (1 - img[..., -1:])).astype(np.float32)
166 | orig_size = img.shape[:2][::-1]
167 | img = cv2.resize(img, self.input_h_w[::-1], interpolation=cv2.INTER_AREA)
168 | return img, orig_size
169 |
170 | def __len__(self):
171 | return len(self.metas)
172 |
173 | def get_K_from_params(params):
174 | K = np.zeros((3, 3)).astype(np.float32)
175 | K[0][0], K[0][2], K[1][2] = params[:3]
176 | K[1][1] = K[0][0]
177 | K[2][2] = 1.
178 | return K
179 |
180 |
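`build_metas` converts the Blender `camera_angle_x` into a pixel focal length and flips the camera axes into OpenCV convention. A self-contained sketch of both steps (the angle value is only an example):

```python
import numpy as np

camera_angle_x = 0.6911  # example horizontal FoV in radians, as read from the JSON
W = 800                  # the synthetic renders here are 800x800

# tan(fov_x / 2) = (W / 2) / focal  =>  focal length in pixels:
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)

K = np.eye(3, dtype=np.float32)
K[0, 0] = K[1, 1] = focal
K[0, 2] = K[1, 2] = W / 2.  # principal point at the image center

# Blender-to-OpenCV flip used above: negate the y and z camera axes.
b2c = np.diag([1., -1., -1., 1.])
```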
--------------------------------------------------------------------------------
/lib/datasets/samplers.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.sampler import Sampler
2 | from torch.utils.data.sampler import BatchSampler
3 | import numpy as np
4 | import torch
5 | import math
6 | import torch.distributed as dist
7 | from lib.config import cfg
8 |
9 | class GefuBatchSampler(Sampler):
10 | def __init__(self, sampler, batch_size, drop_last, sampler_meta):
11 | self.sampler = sampler
12 | self.batch_size = batch_size
13 | self.drop_last = drop_last
14 | self.input_views = sampler_meta.input_views_num
15 | self.views_prob = sampler_meta.input_views_prob
16 | if cfg.fix_random:
17 |             np.random.seed(0)  # the random module was never imported here; NumPy drives np.random.choice below
18 |
19 | def __iter__(self):
20 | batch = []
21 | input_views_num = np.random.choice(self.input_views, 1, p=self.views_prob)
22 | for idx in self.sampler:
23 | batch.append((idx, input_views_num.item()))
24 | if len(batch) == self.batch_size:
25 | input_views_num = np.random.choice(self.input_views, 1, p=self.views_prob)
26 | yield batch
27 | batch = []
28 | if len(batch) > 0 and not self.drop_last:
29 | yield batch
30 |
31 | def __len__(self):
32 | if self.drop_last:
33 | return len(self.sampler) // self.batch_size
34 | else:
35 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
36 |
37 |
38 | class ImageSizeBatchSampler(Sampler):
39 | def __init__(self, sampler, batch_size, drop_last, sampler_meta):
40 | self.sampler = sampler
41 | self.batch_size = batch_size
42 | self.drop_last = drop_last
43 | self.strategy = sampler_meta.strategy
44 | self.hmin, self.wmin = sampler_meta.min_hw
45 | self.hmax, self.wmax = sampler_meta.max_hw
46 | self.divisor = 32
47 | if cfg.fix_random:
48 | np.random.seed(0)
49 |
50 | def generate_height_width(self):
51 | if self.strategy == 'origin':
52 | return -1, -1
53 | h = np.random.randint(self.hmin, self.hmax + 1)
54 | w = np.random.randint(self.wmin, self.wmax + 1)
55 | h = (h | (self.divisor - 1)) + 1
56 | w = (w | (self.divisor - 1)) + 1
57 | return h, w
58 |
59 | def __iter__(self):
60 | batch = []
61 | h, w = self.generate_height_width()
62 | for idx in self.sampler:
63 | batch.append((idx, h, w))
64 | if len(batch) == self.batch_size:
65 | h, w = self.generate_height_width()
66 | yield batch
67 | batch = []
68 | if len(batch) > 0 and not self.drop_last:
69 | yield batch
70 |
71 | def __len__(self):
72 | if self.drop_last:
73 | return len(self.sampler) // self.batch_size
74 | else:
75 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
76 |
77 |
78 | class IterationBasedBatchSampler(BatchSampler):
79 | """
80 | Wraps a BatchSampler, resampling from it until
81 | a specified number of iterations have been sampled
82 | """
83 |
84 | def __init__(self, batch_sampler, num_iterations, start_iter=0):
85 | self.batch_sampler = batch_sampler
86 | self.sampler = self.batch_sampler.sampler
87 | self.num_iterations = num_iterations
88 | self.start_iter = start_iter
89 |
90 | def __iter__(self):
91 | iteration = self.start_iter
92 | while iteration <= self.num_iterations:
93 | for batch in self.batch_sampler:
94 | iteration += 1
95 | if iteration > self.num_iterations:
96 | break
97 | yield batch
98 |
99 | def __len__(self):
100 | return self.num_iterations
101 |
102 |
103 | class DistributedSampler(Sampler):
104 | """Sampler that restricts data loading to a subset of the dataset.
105 | It is especially useful in conjunction with
106 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
107 | process can pass a DistributedSampler instance as a DataLoader sampler,
108 | and load a subset of the original dataset that is exclusive to it.
109 | .. note::
110 | Dataset is assumed to be of constant size.
111 | Arguments:
112 | dataset: Dataset used for sampling.
113 | num_replicas (optional): Number of processes participating in
114 | distributed training.
115 | rank (optional): Rank of the current process within num_replicas.
116 | """
117 |
118 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
119 | if num_replicas is None:
120 | if not dist.is_available():
121 | raise RuntimeError("Requires distributed package to be available")
122 | num_replicas = dist.get_world_size()
123 | if rank is None:
124 | if not dist.is_available():
125 | raise RuntimeError("Requires distributed package to be available")
126 | rank = dist.get_rank()
127 | self.dataset = dataset
128 | self.num_replicas = num_replicas
129 | self.rank = rank
130 | self.epoch = 0
131 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
132 | self.total_size = self.num_samples * self.num_replicas
133 | self.shuffle = shuffle
134 |
135 | def __iter__(self):
136 | if self.shuffle:
137 | # deterministically shuffle based on epoch
138 | g = torch.Generator()
139 | g.manual_seed(self.epoch)
140 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
141 | else:
142 | indices = torch.arange(len(self.dataset)).tolist()
143 |
144 | # add extra samples to make it evenly divisible
145 | indices += indices[: (self.total_size - len(indices))]
146 | assert len(indices) == self.total_size
147 |
148 | # subsample
149 | offset = self.num_samples * self.rank
150 | indices = indices[offset:offset+self.num_samples]
151 | assert len(indices) == self.num_samples
152 |
153 | return iter(indices)
154 |
155 | def __len__(self):
156 | return self.num_samples
157 |
158 | def set_epoch(self, epoch):
159 | self.epoch = epoch
160 |
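`ImageSizeBatchSampler.generate_height_width` snaps a random size up to a multiple of `self.divisor` with a bit trick. A quick sketch verifying what `(v | (divisor - 1)) + 1` actually returns:

```python
def snap_up(v, divisor=32):
    # Requires a power-of-two divisor: OR-ing with (divisor - 1) fills the low
    # bits, so +1 lands on the next multiple of divisor strictly above v.
    # Equivalent to (v // divisor + 1) * divisor.
    return (v | (divisor - 1)) + 1

assert snap_up(1) == 32
assert snap_up(33) == 64
assert snap_up(64) == 96  # exact multiples are bumped to the next one, not kept
```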
--------------------------------------------------------------------------------
/lib/datasets/video_utils.py:
--------------------------------------------------------------------------------
1 | from lib.config import cfg
2 | import cv2
3 | import numpy as np
4 | from scipy.spatial.transform import Rotation
5 |
6 |
7 | def get_interpolate_render_path(c2ws, N_views=30):
8 | N = len(c2ws)
9 | rotvec, positions = [], []
10 | rotvec_inteplat, positions_inteplat = [], []
11 | weight = np.linspace(1.0, .0, N_views//3, endpoint=False).reshape(-1, 1)
12 | for i in range(N):
13 | r = Rotation.from_matrix(c2ws[i, :3, :3])
14 | euler_ange = r.as_euler('xyz', degrees=True).reshape(1, 3)
15 | if i:
16 | mask = np.abs(euler_ange - rotvec[0]) > 180
17 | euler_ange[mask] += 360.0
18 | rotvec.append(euler_ange)
19 | positions.append(c2ws[i, :3, 3:].reshape(1, 3))
20 |
21 | if i:
22 | rotvec_inteplat.append(weight * rotvec[i - 1] + (1.0 - weight) * rotvec[i])
23 | positions_inteplat.append(weight * positions[i - 1] + (1.0 - weight) * positions[i])
24 |
25 | rotvec_inteplat.append(weight * rotvec[-1] + (1.0 - weight) * rotvec[0])
26 | positions_inteplat.append(weight * positions[-1] + (1.0 - weight) * positions[0])
27 |
28 | c2ws_render = []
29 | angles_inteplat, positions_inteplat = np.concatenate(rotvec_inteplat), np.concatenate(positions_inteplat)
30 | for rotvec, position in zip(angles_inteplat, positions_inteplat):
31 | c2w = np.eye(4)
32 | c2w[:3, :3] = Rotation.from_euler('xyz', rotvec, degrees=True).as_matrix()
33 | c2w[:3, 3:] = position.reshape(3, 1)
34 | c2ws_render.append(c2w.copy())
35 | c2ws_render = np.stack(c2ws_render)
36 | return c2ws_render
37 |
38 |
39 |
40 |
41 |
42 | def normalize(v):
43 | """Normalize a vector."""
44 | return v / np.linalg.norm(v)
45 |
46 |
47 | def average_poses(poses):
48 | """
49 | Calculate the average pose, which is then used to center all poses
50 | using @center_poses. Its computation is as follows:
51 | 1. Compute the center: the average of pose centers.
52 | 2. Compute the z axis: the normalized average z axis.
53 | 3. Compute axis y': the average y axis.
54 | 4. Compute x' = y' cross product z, then normalize it as the x axis.
55 | 5. Compute the y axis: z cross product x.
56 |
57 | Note that at step 3, we cannot directly use y' as y axis since it's
58 | not necessarily orthogonal to z axis. We need to pass from x to y.
59 | Inputs:
60 | poses: (N_images, 3, 4)
61 | Outputs:
62 | pose_avg: (3, 4) the average pose
63 | """
64 | # 1. Compute the center
65 | center = poses[..., 3].mean(0) # (3)
66 | # 2. Compute the z axis
67 | z = normalize(poses[..., 2].mean(0)) # (3)
68 | # 3. Compute axis y' (no need to normalize as it's not the final output)
69 | y_ = poses[..., 1].mean(0) # (3)
70 | # 4. Compute the x axis
71 | x = normalize(np.cross(y_, z)) # (3)
72 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
73 | y = np.cross(z, x) # (3)
74 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
75 | return pose_avg
76 |
77 | def center_poses(poses, blender2opencv):
78 | """
79 | Center the poses so that we can use NDC.
80 | See https://github.com/bmild/nerf/issues/34
81 | Inputs:
82 | poses: (N_images, 3, 4)
83 | Outputs:
84 | poses_centered: (N_images, 3, 4) the centered poses
85 | pose_avg: (3, 4) the average pose
86 | """
87 | pose_avg = average_poses(poses) # (3, 4)
88 | pose_avg_homo = np.eye(4)
89 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation
90 | # by simply adding 0, 0, 0, 1 as the last row
91 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
92 | poses_homo = np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate
93 |
94 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
95 | poses_centered = poses_centered @ blender2opencv
96 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
97 |
98 | return poses_centered
99 |
100 | # --------------------- Start: Render scripts for LLFF dataset video generation --------------------
101 | def get_spiral_render_path(c2ws_all, near_far, rads_scale=0.5, N_views=120):
102 | # center pose
103 | c2w = poses_avg(c2ws_all)
104 |
105 | # Get average pose
106 | up = normalize(c2ws_all[:, :3, 1].sum(0))
107 |
108 | # Find a reasonable "focus depth" for this dataset
109 | close_depth, inf_depth = near_far
110 | # print(near_far)
111 | dt = .75
112 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
113 | focal = mean_dz
114 | # print(focal)
115 | # Get radii for spiral path
116 | shrink_factor = .8
117 | zdelta = close_depth * .2
118 | tt = c2ws_all[:, :3, 3] - c2w[:3, 3][None]
119 | rads = np.percentile(np.abs(tt), 70, 0)*rads_scale
120 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
121 | return np.stack(render_poses)
122 |
123 | def poses_avg(poses):
124 | center = poses[:, :3, 3].mean(0)
125 | vec2 = normalize(poses[:, :3, 2].sum(0))
126 | up = poses[:, :3, 1].sum(0)
127 | c2w = viewmatrix(vec2, up, center)
128 |
129 | return c2w
130 |
131 |
132 | def normalize(x):
133 | return x / np.linalg.norm(x, axis=-1, keepdims=True)
134 |
135 |
136 | def viewmatrix(z, up, pos):
137 | vec2 = normalize(z)
138 | vec1_avg = up
139 | vec0 = normalize(np.cross(vec1_avg, vec2))
140 | vec1 = normalize(np.cross(vec2, vec0))
141 | m = np.eye(4)
142 | m[:3] = np.stack([vec0, vec1, vec2, pos], 1)
143 | return m
144 |
145 |
146 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
147 | render_poses = []
148 | rads = np.array(list(rads) + [1.])
149 |
150 | for theta in np.linspace(0., 2. * np.pi * N_rots, N+1)[:-1]:
151 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
152 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
153 | render_poses.append(viewmatrix(z, up, c))
154 | return render_poses
155 | # --------------------- End: Render scripts for LLFF dataset video generation --------------------
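`viewmatrix` assembles a camera-to-world frame from a forward vector, an up hint, and a position. A small self-contained check (the helpers are restated locally so the snippet runs standalone; the vectors are toy values):

```python
import numpy as np

def normalize(x):
    return x / np.linalg.norm(x)

def viewmatrix(z, up, pos):
    vec2 = normalize(z)                     # camera forward
    vec0 = normalize(np.cross(up, vec2))    # camera right
    vec1 = normalize(np.cross(vec2, vec0))  # recomputed, exactly orthogonal up
    m = np.eye(4)
    m[:3] = np.stack([vec0, vec1, vec2, pos], 1)
    return m

c2w = viewmatrix(z=np.array([0., 0., 1.]),
                 up=np.array([0., 1., 0.]),
                 pos=np.array([0., 0., -4.]))
R = c2w[:3, :3]
assert np.allclose(R @ R.T, np.eye(3), atol=1e-6)  # rotation part is orthonormal
```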
--------------------------------------------------------------------------------
/lib/evaluators/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_evaluator import make_evaluator
2 |
--------------------------------------------------------------------------------
/lib/evaluators/gefu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lib.config import cfg
3 | import os
4 | import sys
5 | import imageio
6 | from lib.utils import img_utils
7 | from skimage.metrics import structural_similarity as ssim
8 | from skimage.metrics import peak_signal_noise_ratio as psnr
9 | import torch.nn.functional as F
10 | import torch
11 | import lpips
12 |
13 |
14 | import cv2
15 | import torchvision.transforms as T
16 | from PIL import Image
17 | import matplotlib.pyplot as plt
18 |
19 |
20 | def write_cam(file, ixt, ext):
21 | f = open(file, "w")
22 | f.write('extrinsic\n')
23 | for i in range(0, 4):
24 | for j in range(0, 4):
25 | f.write(str(ext[i][j]) + ' ')
26 | f.write('\n')
27 | f.write('\n')
28 |
29 | f.write('intrinsic\n')
30 | for i in range(0, 3):
31 | for j in range(0, 3):
32 | f.write(str(ixt[i][j]) + ' ')
33 | f.write('\n')
34 |
35 | f.close()
36 |
37 |
38 | def save_pfm(filename, image, scale=1):
39 | file = open(filename, "wb")
40 | color = None
41 |
42 | image = np.flipud(image)
43 |
44 | if image.dtype.name != 'float32':
45 | raise Exception('Image dtype must be float32.')
46 |
47 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
48 | color = True
49 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
50 | color = False
51 | else:
52 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
53 |
54 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
55 | file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8'))
56 |
57 | endian = image.dtype.byteorder
58 |
59 |     if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
60 | scale = -scale
61 |
62 | file.write(('%f\n' % scale).encode('utf-8'))
63 |
64 | image.tofile(file)
65 | file.close()
66 |
67 |
68 | def unpreprocess(data, shape=(1, 1, 3, 1, 1), render_scale=1.):
69 | device = data.device
70 | # mean = torch.tensor([-0.485/0.229, -0.456/0.224, -0.406 / 0.225]).view(*shape).to(device)
71 | # std = torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225]).view(*shape).to(device)
72 | img = data * 0.5 + 0.5
73 | B, S, C, H, W = img.shape
74 | img = F.interpolate(img.reshape(B*S, C, H, W), scale_factor=render_scale, align_corners=True, mode='bilinear', recompute_scale_factor=True).reshape(B, S, C, int(H*render_scale), int(W*render_scale))
75 | return img
76 |
77 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
78 | if type(depth) is not np.ndarray:
79 | depth = depth.cpu().numpy()
80 |
81 | x = np.nan_to_num(depth) # change nan to 0
82 | if minmax is None:
83 | mi = np.min(x[x > 0]) # get minimum positive depth (ignore background)
84 | ma = np.max(x)
85 | else:
86 | mi, ma = minmax
87 |
88 | x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1
89 | x = (255 * x).astype(np.uint8)
90 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
91 | x_ = T.ToTensor()(x_) # (3, H, W)
92 | return x_, [mi, ma]
93 |
94 |
95 |
96 | class Evaluator:
97 |
98 | def __init__(self,):
99 | self.psnrs = []
100 | self.ssims = []
101 | self.lpips = []
102 | self.scene_psnrs = {}
103 | self.scene_ssims = {}
104 | self.scene_lpips = {}
105 | self.loss_fn_vgg = lpips.LPIPS(net='vgg')
106 | self.loss_fn_vgg.cuda()
107 | if cfg.gefu.eval_depth:
108 | # Following the setup of MVSNeRF
109 | self.eval_depth_scenes = ['scan1', 'scan8', 'scan21', 'scan103', 'scan110']
110 | self.abs = []
111 | self.acc_2 = []
112 | self.acc_10 = []
113 | self.mvs_abs = []
114 | self.mvs_acc_2 = []
115 | self.mvs_acc_10 = []
116 |         os.makedirs(cfg.result_dir, exist_ok=True)
117 |
118 | def evaluate(self, output, batch):
119 | B, S, _, H, W = batch['src_inps'].shape
120 | for i in range(cfg.gefu.cas_config.num):
121 | if not cfg.gefu.cas_config.render_if[i]:
122 | continue
123 | render_scale = cfg.gefu.cas_config.render_scale[i]
124 | h, w = int(H*render_scale), int(W*render_scale)
125 | pred_rgb = output[f'rgb_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy()
126 | gt_rgb = batch[f'rgb_{i}'].reshape(B, h, w, 3).detach().cpu().numpy()
127 | masks = (batch[f'msk_{i}'].reshape(B, h, w).cpu().numpy() >= 1).astype(np.uint8)
128 | if i == cfg.gefu.cas_config.num-1:
129 | pred_rgb_b = output[f'rgb_b_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy()
130 | pred_rgb_r = output[f'rgb_r_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy()
131 |
132 | if cfg.gefu.eval_center:
133 | H_crop, W_crop = int(h*0.1), int(w*0.1)
134 | pred_rgb = pred_rgb[:, H_crop:-H_crop, W_crop:-W_crop]
135 | gt_rgb = gt_rgb[:, H_crop:-H_crop, W_crop:-W_crop]
136 | masks = masks[:, H_crop:-H_crop, W_crop:-W_crop]
137 | if i == cfg.gefu.cas_config.num-1:
138 | pred_rgb_b = pred_rgb_b[:, H_crop:-H_crop, W_crop:-W_crop]
139 | pred_rgb_r = pred_rgb_r[:, H_crop:-H_crop, W_crop:-W_crop]
140 | for b in range(B):
141 | if not batch['meta']['scene'][b]+f'_level{i}' in self.scene_psnrs:
142 | self.scene_psnrs[batch['meta']['scene'][b]+f'_level{i}'] = []
143 | self.scene_ssims[batch['meta']['scene'][b]+f'_level{i}'] = []
144 | self.scene_lpips[batch['meta']['scene'][b]+f'_level{i}'] = []
145 | if cfg.save_result and i == 1:
146 | img = img_utils.horizon_concate(gt_rgb[b], pred_rgb[b])
147 | img_path = os.path.join(cfg.result_dir, '{}_{}_{}.png'.format(batch['meta']['scene'][b], batch['meta']['tar_view'][b].item(), batch['meta']['frame_id'][b].item()))
148 | imageio.imwrite(img_path, (img*255.).astype(np.uint8))
149 |
150 | if cfg.save_ply:
151 | dataset_name = cfg.train_dataset_module.split('.')[-2]
152 | ply_dir = os.path.join(cfg.result_dir, 'pointclouds', dataset_name)
153 | os.makedirs(ply_dir, exist_ok = True)
154 | scan_dir = os.path.join(ply_dir, batch['meta']['scene'][0])
155 | os.makedirs(scan_dir, exist_ok = True)
156 | img_dir = os.path.join(scan_dir, 'images')
157 | os.makedirs(img_dir, exist_ok = True)
158 | img_path = os.path.join(img_dir, '{}_{}_{}.png'.format(batch['meta']['scene'][0], batch['meta']['tar_view'][0].item(), batch['meta']['frame_id'][0].item()))
159 | img = output[f'rgb_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy()
160 | img = (img[b]*255).astype(np.uint8)
161 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
162 | cv2.imwrite(img_path, img)
163 |
164 | cam_dir = os.path.join(scan_dir, 'cam')
165 | os.makedirs(cam_dir, exist_ok = True)
166 | cam_path = os.path.join(cam_dir, '{}_{}_{}.txt'.format(batch['meta']['scene'][0], batch['meta']['tar_view'][0].item(), batch['meta']['frame_id'][0].item()))
167 |
168 | ixt = batch['tar_ixt'].detach().cpu().numpy()[0]
169 | ext = batch['tar_ext'].detach().cpu().numpy()[0]
170 | write_cam(cam_path, ixt, ext)
171 |
172 | nerf_depth = output['depth_level1'].cpu().numpy()[b].reshape((h, w))
173 | depth = nerf_depth
174 | depth_dir = os.path.join(scan_dir, 'depth')
175 | os.makedirs(depth_dir, exist_ok = True)
176 | depth_path = os.path.join(depth_dir, '{}_{}_{}.pfm'.format(batch['meta']['scene'][0], batch['meta']['tar_view'][0].item(), batch['meta']['frame_id'][0].item()))
177 | save_pfm(depth_path, depth)
178 |
179 | depth_minmax = [
180 | batch["near_far"].min().detach().cpu().numpy(),
181 | batch["near_far"].max().detach().cpu().numpy(),
182 | ]
183 | rendered_depth_vis, _ = visualize_depth(depth, depth_minmax)
184 | rendered_depth_vis = rendered_depth_vis.permute(1,2,0).detach().cpu().numpy()
185 | depth_vis_path = os.path.join(depth_dir, '{}_{}_{}.png'.format(batch['meta']['scene'][0], batch['meta']['tar_view'][0].item(), batch['meta']['frame_id'][0].item()))
186 | imageio.imwrite(depth_vis_path, (rendered_depth_vis*255.).astype(np.uint8))
187 |
188 |
189 | mask = masks[b] == 1
190 |                 gt_rgb[b][~mask] = 0.
191 |                 pred_rgb[b][~mask] = 0.
192 |
193 | psnr_item = psnr(gt_rgb[b][mask], pred_rgb[b][mask], data_range=1.)
194 | if i == cfg.gefu.cas_config.num-1:
195 | self.psnrs.append(psnr_item)
196 | self.scene_psnrs[batch['meta']['scene'][b]+f'_level{i}'].append(psnr_item)
197 |
198 | ssim_item = ssim(gt_rgb[b], pred_rgb[b], multichannel=True)
199 | if i == cfg.gefu.cas_config.num-1:
200 | self.ssims.append(ssim_item)
201 | self.scene_ssims[batch['meta']['scene'][b]+f'_level{i}'].append(ssim_item)
202 |
203 | if cfg.eval_lpips:
204 | gt, pred = torch.Tensor(gt_rgb[b])[None].permute(0, 3, 1, 2), torch.Tensor(pred_rgb[b])[None].permute(0, 3, 1, 2)
205 | gt, pred = (gt-0.5)*2., (pred-0.5)*2.
206 | lpips_item = self.loss_fn_vgg(gt.cuda(), pred.cuda()).item()
207 | if i == cfg.gefu.cas_config.num-1:
208 | self.lpips.append(lpips_item)
209 | self.scene_lpips[batch['meta']['scene'][b]+f'_level{i}'].append(lpips_item)
210 |
211 | if cfg.gefu.eval_depth and (i == cfg.gefu.cas_config.num - 1) and batch['meta']['scene'][b] in self.eval_depth_scenes:
212 | nerf_depth = output['depth_level1'].cpu().numpy()[b].reshape((h, w))
213 | mvs_depth = output['depth_mvs_level1'].cpu().numpy()[b]
214 | nerf_gt_depth = batch['tar_dpt'][b].cpu().numpy().reshape(*nerf_depth.shape)
215 | mvs_gt_depth = cv2.resize(nerf_gt_depth, mvs_depth.shape[::-1], interpolation=cv2.INTER_NEAREST)
216 | # nerf_mask = np.logical_and(nerf_gt_depth > 425., nerf_gt_depth < 905.)
217 | # mvs_mask = np.logical_and(mvs_gt_depth > 425., mvs_gt_depth < 905.)
218 | nerf_mask = nerf_gt_depth != 0.
219 | mvs_mask = mvs_gt_depth != 0.
220 | self.abs.append(np.abs(nerf_depth[nerf_mask] - nerf_gt_depth[nerf_mask]).mean())
221 | self.acc_2.append((np.abs(nerf_depth[nerf_mask] - nerf_gt_depth[nerf_mask]) < 2).mean())
222 | self.acc_10.append((np.abs(nerf_depth[nerf_mask] - nerf_gt_depth[nerf_mask]) < 10).mean())
223 | self.mvs_abs.append((np.abs(mvs_depth[mvs_mask] - mvs_gt_depth[mvs_mask])).mean())
224 | self.mvs_acc_2.append((np.abs(mvs_depth[mvs_mask] - mvs_gt_depth[mvs_mask]) < 2.).mean())
225 | self.mvs_acc_10.append((np.abs(mvs_depth[mvs_mask] - mvs_gt_depth[mvs_mask]) < 10.).mean())
226 |
227 | def summarize(self):
228 | ret = {}
229 | ret.update({'psnr': np.mean(self.psnrs)})
230 | ret.update({'ssim': np.mean(self.ssims)})
231 | if cfg.eval_lpips:
232 | ret.update({'lpips': np.mean(self.lpips)})
233 | print('='*30)
234 | for scene in self.scene_psnrs:
235 | if cfg.eval_lpips:
236 | print(scene.ljust(16), 'psnr: {:.2f} ssim: {:.3f} lpips:{:.3f}'.format(np.mean(self.scene_psnrs[scene]), np.mean(self.scene_ssims[scene]), np.mean(self.scene_lpips[scene])))
237 | else:
238 | print(scene.ljust(16), 'psnr: {:.2f} ssim: {:.3f} '.format(np.mean(self.scene_psnrs[scene]), np.mean(self.scene_ssims[scene])))
239 | print('='*30)
240 | print(ret)
241 | if cfg.gefu.eval_depth:
242 | depth_ret = {}
243 | keys = ['abs', 'acc_2', 'acc_10']
244 | for key in keys:
245 | depth_ret[key] = np.mean(getattr(self, key))
246 | setattr(self, key, [])
247 | print(depth_ret)
248 | keys = ['mvs_abs', 'mvs_acc_2', 'mvs_acc_10']
249 | depth_ret = {}
250 | for key in keys:
251 | depth_ret[key] = np.mean(getattr(self, key))
252 | setattr(self, key, [])
253 | print(depth_ret)
254 | self.psnrs = []
255 | self.ssims = []
256 | self.lpips = []
257 | self.scene_psnrs = {}
258 | self.scene_ssims = {}
259 | self.scene_lpips = {}
260 | if cfg.save_result:
261 | print('Save visualization results to: {}'.format(cfg.result_dir))
262 | return ret
263 |
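The evaluator computes PSNR only over pixels where the mask is set; boolean indexing flattens both images to the valid pixels before comparison. A minimal sketch of that masked-metric pattern on random data (all shapes and values illustrative):

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr

rng = np.random.default_rng(0)
gt = rng.random((60, 80, 3)).astype(np.float32)
pred = np.clip(gt + rng.normal(0., 0.05, gt.shape), 0., 1.).astype(np.float32)

mask = np.ones((60, 80), dtype=bool)
mask[:10] = False  # pretend the top rows are invalid

# Mirrors psnr(gt_rgb[b][mask], pred_rgb[b][mask], data_range=1.) above:
score = psnr(gt[mask], pred[mask], data_range=1.)
print(f'masked PSNR: {score:.2f} dB')
```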
--------------------------------------------------------------------------------
/lib/evaluators/make_evaluator.py:
--------------------------------------------------------------------------------
1 | import imp
2 | import os
3 |
4 | def _evaluator_factory(cfg):
5 | module = cfg.evaluator_module
6 | path = cfg.evaluator_path
7 | evaluator = imp.load_source(module, path).Evaluator()
8 | return evaluator
9 |
10 |
11 | def make_evaluator(cfg):
12 | if cfg.skip_eval:
13 | return None
14 | else:
15 | return _evaluator_factory(cfg)
16 |
--------------------------------------------------------------------------------
/lib/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_network import make_network
2 |
--------------------------------------------------------------------------------
/lib/networks/gefu/cost_reg_net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .utils import *
3 |
4 | class CostRegNet(nn.Module):
5 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
6 | super(CostRegNet, self).__init__()
7 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
8 |
9 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
10 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
11 |
12 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
13 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
14 |
15 | self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act)
16 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)
17 |
18 | self.conv7 = nn.Sequential(
19 | nn.ConvTranspose3d(64, 32, 3, padding=1, output_padding=1,
20 | stride=2, bias=False),
21 | norm_act(32))
22 |
23 | self.conv9 = nn.Sequential(
24 | nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1,
25 | stride=2, bias=False),
26 | norm_act(16))
27 |
28 | self.conv11 = nn.Sequential(
29 | nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,
30 | stride=2, bias=False),
31 | norm_act(8))
32 | self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
33 | self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))
34 |
35 | def forward(self, x):
36 | conv0 = self.conv0(x)
37 | conv2 = self.conv2(self.conv1(conv0))
38 | conv4 = self.conv4(self.conv3(conv2))
39 | x = self.conv6(self.conv5(conv4))
40 | x = conv4 + self.conv7(x)
41 | del conv4
42 | x = conv2 + self.conv9(x)
43 | del conv2
44 | x = conv0 + self.conv11(x)
45 | del conv0
46 | feat = self.feat_conv(x)
47 | depth = self.depth_conv(x)
48 | return feat, depth.squeeze(1)
49 |
50 |
51 | class MinCostRegNet(nn.Module):
52 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
53 | super(MinCostRegNet, self).__init__()
54 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
55 |
56 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
57 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
58 |
59 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
60 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
61 |
62 | self.conv9 = nn.Sequential(
63 | nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1,
64 | stride=2, bias=False),
65 | norm_act(16))
66 |
67 | self.conv11 = nn.Sequential(
68 | nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1,
69 | stride=2, bias=False),
70 | norm_act(8))
71 |
72 | self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False))
73 | self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False))
74 |
75 | def forward(self, x):
76 | conv0 = self.conv0(x)
77 | conv2 = self.conv2(self.conv1(conv0))
78 | conv4 = self.conv4(self.conv3(conv2))
79 | x = conv4
80 | x = conv2 + self.conv9(x)
81 | del conv2
82 | x = conv0 + self.conv11(x)
83 | del conv0
84 | feat = self.feat_conv(x)
85 | depth = self.depth_conv(x)
86 | return feat, depth.squeeze(1)
87 |
88 |
89 |
90 | class SigCostRegNet(nn.Module):
91 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d):
92 | super(SigCostRegNet, self).__init__()
93 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
94 |
95 | self.conv1 = ConvBnReLU3D(8, 16, stride=(1,2,2), pad=(1,1,1),norm_act=norm_act)
96 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
97 |
98 | self.conv3 = ConvBnReLU3D(16, 32, stride=(1,2,2), pad=(1,1,1), norm_act=norm_act)
99 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
100 |
101 | self.conv5 = ConvBnReLU3D(32, 64, stride=(1,2,2), pad=(1,1,1), norm_act=norm_act)
102 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)
103 |
104 | self.conv7 = nn.Sequential(
105 | nn.ConvTranspose3d(64, 32, 3, padding=(1,1,1), output_padding=(0,1,1),
106 | stride=(1,2,2), bias=False),
107 | norm_act(32))
108 |
109 | self.conv9 = nn.Sequential(
110 | nn.ConvTranspose3d(32, 16, 3, padding=(1,1,1), output_padding=(0,1,1),
111 | stride=(1,2,2), bias=False),
112 | norm_act(16))
113 |
114 | self.conv11 = nn.Sequential(
115 | nn.ConvTranspose3d(16, 8, 3, padding=(1,1,1), output_padding=(0,1,1),
116 | stride=(1,2,2), bias=False),
117 | norm_act(8))
118 |
119 | self.feat_conv = nn.Sequential(nn.Conv3d(8, in_channels, 3, padding=1, bias=False))
120 |
121 | def forward(self, x):
122 | conv0 = self.conv0(x)
123 | conv2 = self.conv2(self.conv1(conv0))
124 | conv4 = self.conv4(self.conv3(conv2))
125 | x = self.conv6(self.conv5(conv4))
126 | x = conv4 + self.conv7(x)
127 | del conv4
128 | x = conv2 + self.conv9(x)
129 | del conv2
130 | x = conv0 + self.conv11(x)
131 | del conv0
132 | feat = self.feat_conv(x)
133 | return feat
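Each stride-2 stage in these regularisers halves D, H and W, and each `ConvTranspose3d` doubles them, so the skip additions (`conv0 + self.conv11(x)` and friends) only line up when the input volume is evenly divisible per level (by 8 for the full three-level `CostRegNet`). A single-level sketch of that constraint with plain PyTorch layers:

```python
import torch
import torch.nn as nn

down = nn.Conv3d(8, 8, 3, stride=2, padding=1)
up = nn.ConvTranspose3d(8, 8, 3, stride=2, padding=1, output_padding=1)

x = torch.randn(1, 8, 16, 64, 80)      # even D, H, W: round trip preserves shape
assert up(down(x)).shape == x.shape

x_bad = torch.randn(1, 8, 15, 64, 80)  # odd depth: 15 -> 8 -> 16, skip would fail
assert up(down(x_bad)).shape != x_bad.shape
```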
--------------------------------------------------------------------------------
/lib/networks/gefu/feature_net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .utils import *
3 |
4 | class FeatureNet(nn.Module):
5 | def __init__(self, norm_act=nn.BatchNorm2d):
6 | super(FeatureNet, self).__init__()
7 | self.conv0 = nn.Sequential(
8 | ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act),
9 | ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act))
10 | self.conv1 = nn.Sequential(
11 | ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act),
12 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act))
13 | self.conv2 = nn.Sequential(
14 | ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act),
15 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act))
16 |
17 | self.toplayer = nn.Conv2d(32, 32, 1)
18 | self.lat1 = nn.Conv2d(16, 32, 1)
19 | self.lat0 = nn.Conv2d(8, 32, 1)
20 |
21 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
22 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)
23 |
24 | def _upsample_add(self, x, y):
25 | return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y
26 |
27 | def forward(self, x):
28 | conv0 = self.conv0(x)
29 | conv1 = self.conv1(conv0)
30 | conv2 = self.conv2(conv1)
31 | feat2 = self.toplayer(conv2)
32 | feat1 = self._upsample_add(feat2, self.lat1(conv1))
33 | feat0 = self._upsample_add(feat1, self.lat0(conv0))
34 | feat1 = self.smooth1(feat1)
35 | feat0 = self.smooth0(feat0)
36 | return feat2, feat1, feat0
37 |
38 |
39 | class CNNRender(nn.Module):
40 | def __init__(self, norm_act=nn.BatchNorm2d):
41 | super(CNNRender, self).__init__()
42 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act)
43 | self.conv1 = ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act)
44 | self.conv2 = nn.Conv2d(8, 16, 1)
45 | self.conv3 = nn.Conv2d(16, 3, 1)
46 |
47 | def _upsample_add(self, x, y):
48 | return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y
49 |
50 | def forward(self, x):
51 | conv0 = self.conv0(x)
52 | conv1 = self.conv1(conv0)
53 | conv2 = self._upsample_add(conv1, self.conv2(conv0))
54 | conv3 = self.conv3(conv2)
55 | return torch.clamp(conv3+x, 0., 1.)
56 |
57 |
58 | class AutoEncoder(nn.Module):
59 | def __init__(self, in_channels, hid_channels, out_channels, norm_act=nn.BatchNorm2d):
60 | super(AutoEncoder, self).__init__()
61 | self.conv0 = nn.Sequential(
62 | ConvBnReLU(in_channels,hid_channels, 3, 1, 1, norm_act=norm_act),
63 | ConvBnReLU(hid_channels, hid_channels, 3, 1, 1, norm_act=norm_act))
64 | self.conv1 = nn.Sequential(
65 | ConvBnReLU(hid_channels, hid_channels*2, 5, 2, 2, norm_act=norm_act),
66 | ConvBnReLU(hid_channels*2, hid_channels*2, 3, 1, 1, norm_act=norm_act))
67 | self.conv2 = nn.Sequential(
68 | ConvBnReLU(hid_channels*2, hid_channels*4, 5, 2, 2, norm_act=norm_act),
69 | ConvBnReLU(hid_channels*4, hid_channels*4, 3, 1, 1, norm_act=norm_act))
70 |
71 | self.toplayer = nn.Conv2d(hid_channels*4, out_channels, 1)
72 | self.lat1 = nn.Conv2d(hid_channels*2, out_channels, 1)
73 | self.lat0 = nn.Conv2d(hid_channels, out_channels, 1)
74 |
75 | self.out = nn.Conv2d(out_channels, out_channels, 3, padding=1)
76 |
77 | def _upsample_add(self, x, y):
78 | return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y
79 |
80 | def forward(self, x):
81 | conv0 = self.conv0(x)
82 | conv1 = self.conv1(conv0)
83 | conv2 = self.conv2(conv1)
84 | feat2 = self.toplayer(conv2)
85 | feat1 = self._upsample_add(feat2, self.lat1(conv1))
86 | feat0 = self._upsample_add(feat1, self.lat0(conv0))
87 | out = self.out(feat0)
88 | return out
89 |
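`FeatureNet` is a small feature pyramid: `feat2`, `feat1` and `feat0` come out at 1/4, 1/2 and full resolution with 32, 16 and 8 channels. The `_upsample_add` fusion only works because adjacent levels differ by exactly 2x; a standalone sketch of that step:

```python
import torch
import torch.nn.functional as F

coarse = torch.randn(1, 32, 64, 80)     # 1/4-resolution feature (toplayer output)
lateral = torch.randn(1, 32, 128, 160)  # 1/2-resolution lateral (lat1 output)

fused = F.interpolate(coarse, scale_factor=2, mode='bilinear',
                      align_corners=True) + lateral
assert fused.shape == lateral.shape

# For a 256x320 input, FeatureNet therefore returns:
#   feat2: (B, 32, 64, 80), feat1: (B, 16, 128, 160), feat0: (B, 8, 256, 320)
```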
--------------------------------------------------------------------------------
/lib/networks/gefu/nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from lib.config import cfg
5 | from .cost_reg_net import SigCostRegNet
6 | class ScaledDotProductAttention(nn.Module):
7 | ''' Scaled Dot-Product Attention '''
8 |
9 | def __init__(self, temperature, attn_dropout=0.1):
10 | super().__init__()
11 | self.temperature = temperature
12 | # self.dropout = nn.Dropout(attn_dropout)
13 |
14 | def forward(self, q, k, v, mask=None):
15 |
16 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
17 |
18 | if mask is not None:
19 | attn = attn.masked_fill(mask == 0, -1e9)
20 | # attn = attn * mask
21 |
22 | attn = F.softmax(attn, dim=-1)
23 | # attn = self.dropout(F.softmax(attn, dim=-1))
24 | output = torch.matmul(attn, v)
25 |
26 | return output, attn
27 |
28 | class MultiHeadAttention(nn.Module):
29 | ''' Multi-Head Attention module '''
30 |
31 | def __init__(self, n_head, d_model, d_model_k, d_k, d_v, dropout=0.1):
32 | super().__init__()
33 |
34 | self.n_head = n_head
35 | self.d_k = d_k
36 | self.d_v = d_v
37 |
38 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
39 | self.w_ks = nn.Linear(d_model_k, n_head * d_k, bias=False)
40 | self.w_vs = nn.Linear(d_model_k, n_head * d_v, bias=False)
41 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
42 |
43 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
44 |
45 | # self.dropout = nn.Dropout(dropout)
46 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
47 |
48 | def forward(self, q, k, v, mask=None):
49 |
50 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
51 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
52 |
53 | residual = q
54 |
55 | # Pass through the pre-attention projection: b x lq x (n*dv)
56 | # Separate different heads: b x lq x n x dv
57 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
58 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
59 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
60 |
61 | # Transpose for attention dot product: b x n x lq x dv
62 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
63 |
64 | if mask is not None:
65 | mask = mask.unsqueeze(1) # For head axis broadcasting.
66 |
67 | q, attn = self.attention(q, k, v, mask=mask)
68 |
69 | # Transpose to move the head dimension back: b x lq x n x dv
70 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
71 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
72 | # q = self.dropout(self.fc(q))
73 | q = self.fc(q)
74 | q += residual
75 |
76 | return q, attn
77 |
78 | class NeRF(nn.Module):
79 | def __init__(self, hid_n=64, feat_ch=16+3):
80 |         """Radiance-field head: aggregates multi-view image features with
81 |         attention and predicts a blended color plus a density per sample."""
82 | super(NeRF, self).__init__()
83 | self.hid_n = hid_n
84 | self.agg = Agg(feat_ch)
85 | self.lr0 = nn.Sequential(nn.Linear(8+16, hid_n),
86 | nn.ReLU())
87 | self.lrs = nn.ModuleList([
88 | nn.Sequential(nn.Linear(hid_n, hid_n), nn.ReLU()) for i in range(0)
89 | ])
90 | self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus())
91 | self.color = nn.Sequential(
92 | nn.Linear(64+24+feat_ch+4, hid_n),
93 | nn.ReLU(),
94 | nn.Linear(hid_n, 1),
95 | nn.ReLU())
96 | self.lr0.apply(weights_init)
97 | self.lrs.apply(weights_init)
98 | self.sigma.apply(weights_init)
99 | self.color.apply(weights_init)
100 | self.regnet = SigCostRegNet(hid_n)
101 | self.attn = MultiHeadAttention(8,hid_n,feat_ch+4,4,4)
102 |
103 |
104 | def forward(self, vox_feat, img_feat_rgb_dir, size):
105 | H,W = size
106 | B, N_points, N_views = img_feat_rgb_dir.shape[:-1]
107 | S = img_feat_rgb_dir.shape[2]
108 | img_feat = self.agg(img_feat_rgb_dir)
109 | vox_img_feat = torch.cat((vox_feat, img_feat), dim=-1)
110 | x = self.lr0(vox_img_feat)
111 | x = x.reshape(B,H,W,-1,x.shape[-1])
112 | x = x.permute(0,4,3,1,2)
113 | x = self.regnet(x)
114 | x = x.permute(0,1,3,4,2).flatten(2).permute(0,2,1)
115 | q = x.squeeze(0).unsqueeze(1)
116 | k = img_feat_rgb_dir.squeeze(0)
117 | x,_ = self.attn(q,k,k)
118 | x = x.squeeze(1).unsqueeze(0)
119 | for i in range(len(self.lrs)):
120 | x = self.lrs[i](x)
121 | sigma = self.sigma(x)
122 | x = torch.cat((x, vox_img_feat), dim=-1)
123 | fea_transmit = x.clone()
124 | x = x.view(B, -1, 1, x.shape[-1]).repeat(1, 1, S, 1)
125 | x = torch.cat((x, img_feat_rgb_dir), dim=-1)
126 | color_weight = F.softmax(self.color(x), dim=-2)
127 | color = torch.sum((img_feat_rgb_dir[..., -7:-4] * color_weight), dim=-2)
128 | return torch.cat([color, sigma], dim=-1), fea_transmit
129 |
130 | class Agg(nn.Module):
131 | def __init__(self, feat_ch):
132 |         """Pools per-view image features into one feature using mean and
133 |         variance context plus learned softmax aggregation weights."""
134 | super(Agg, self).__init__()
135 | self.feat_ch = feat_ch
136 | if cfg.gefu.viewdir_agg:
137 | self.view_fc = nn.Sequential(
138 | nn.Linear(4, feat_ch),
139 | nn.ReLU(),
140 | )
141 | self.view_fc.apply(weights_init)
142 | self.global_fc = nn.Sequential(
143 | nn.Linear(feat_ch*3, 32),
144 | nn.ReLU(),
145 | )
146 |
147 | self.agg_w_fc = nn.Sequential(
148 | nn.Linear(32, 1),
149 | nn.ReLU(),
150 | )
151 | self.fc = nn.Sequential(
152 | nn.Linear(32, 16),
153 | nn.ReLU(),
154 | )
155 | self.global_fc.apply(weights_init)
156 | self.agg_w_fc.apply(weights_init)
157 | self.fc.apply(weights_init)
158 |
159 | def forward(self, img_feat_rgb_dir):
160 | B, S = len(img_feat_rgb_dir), img_feat_rgb_dir.shape[-2]
161 | if cfg.gefu.viewdir_agg:
162 | view_feat = self.view_fc(img_feat_rgb_dir[..., -4:])
163 | img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_feat
164 | else:
165 | img_feat_rgb = img_feat_rgb_dir[..., :-4]
166 |
167 | var_feat = torch.var(img_feat_rgb, dim=-2).view(B, -1, 1, self.feat_ch).repeat(1, 1, S, 1)
168 | avg_feat = torch.mean(img_feat_rgb, dim=-2).view(B, -1, 1, self.feat_ch).repeat(1, 1, S, 1)
169 |
170 | feat = torch.cat([img_feat_rgb, var_feat, avg_feat], dim=-1)
171 | global_feat = self.global_fc(feat)
172 | agg_w = F.softmax(self.agg_w_fc(global_feat), dim=-2)
173 | im_feat = (global_feat * agg_w).sum(dim=-2)
174 | return self.fc(im_feat)
175 |
176 | class MVSNeRF(nn.Module):
177 | def __init__(self, hid_n=64, feat_ch=16+3):
178 |         """MVSNeRF-style baseline head: concatenates per-view features and
179 |         regresses color and density directly, without attention."""
180 | super(MVSNeRF, self).__init__()
181 | self.hid_n = hid_n
182 | self.lr0 = nn.Sequential(nn.Linear(8+feat_ch*3, hid_n),
183 | nn.ReLU())
184 | self.lrs = nn.ModuleList([
185 | nn.Sequential(nn.Linear(hid_n, hid_n), nn.ReLU()) for i in range(0)
186 | ])
187 | self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus())
188 | self.color = nn.Sequential(
189 | nn.Linear(hid_n, hid_n),
190 | nn.ReLU(),
191 | nn.Linear(hid_n, 3))
192 | self.lr0.apply(weights_init)
193 | self.lrs.apply(weights_init)
194 | self.sigma.apply(weights_init)
195 | self.color.apply(weights_init)
196 |
197 | def forward(self, vox_feat, img_feat_rgb_dir):
198 | B, N_points, N_views = img_feat_rgb_dir.shape[:-1]
199 | # img_feat = self.agg(img_feat_rgb_dir)
200 | img_feat = torch.cat([img_feat_rgb_dir[..., i, :-4] for i in range(N_views)] , dim=-1)
201 | S = img_feat_rgb_dir.shape[2]
202 | vox_img_feat = torch.cat((vox_feat, img_feat), dim=-1)
203 | x = self.lr0(vox_img_feat)
204 | for i in range(len(self.lrs)):
205 | x = self.lrs[i](x)
206 | sigma = self.sigma(x)
207 | # x = torch.cat((x, vox_img_feat), dim=-1)
208 | # x = x.view(B, -1, 1, x.shape[-1]).repeat(1, 1, S, 1)
209 | # x = torch.cat((x, img_feat_rgb_dir), dim=-1)
210 | color = torch.sigmoid(self.color(x))
211 | return torch.cat([color, sigma], dim=-1)
212 |
213 |
214 |
215 | def weights_init(m):
216 | if isinstance(m, nn.Linear):
217 | nn.init.kaiming_normal_(m.weight.data)
218 | if m.bias is not None:
219 | nn.init.zeros_(m.bias.data)
220 |
221 |
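Note that `NeRF.forward` never regresses RGB directly: `self.color` produces one logit per source view and the output colour is a softmax-weighted blend of the source-view colours (the `img_feat_rgb_dir[..., -7:-4]` slice). A minimal sketch of that blending step:

```python
import torch
import torch.nn.functional as F

B, N_points, S = 1, 4, 3                 # batch, ray samples, source views
scores = torch.randn(B, N_points, S, 1)  # per-view logits (self.color output)
src_rgb = torch.rand(B, N_points, S, 3)  # per-view colours (the -7:-4 slice)

w = F.softmax(scores, dim=-2)            # normalise across the view axis
blended = (src_rgb * w).sum(dim=-2)      # (B, N_points, 3)
assert torch.allclose(w.sum(dim=-2), torch.ones(B, N_points, 1))
```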
--------------------------------------------------------------------------------
/lib/networks/gefu/res_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class ResidualConv(nn.Module):
5 | def __init__(self, input_dim, output_dim, stride, padding):
6 | super(ResidualConv, self).__init__()
7 |
8 | self.conv_block = nn.Sequential(
9 | nn.BatchNorm2d(input_dim),
10 | nn.ReLU(),
11 | nn.Conv2d(
12 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
13 | ),
14 | nn.BatchNorm2d(output_dim),
15 | nn.ReLU(),
16 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
17 | )
18 | self.conv_skip = nn.Sequential(
19 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
20 | nn.BatchNorm2d(output_dim),
21 | )
22 |
23 | def forward(self, x):
24 |
25 | return self.conv_block(x) + self.conv_skip(x)
26 |
27 | class Upsample(nn.Module):
28 | def __init__(self, input_dim, output_dim, kernel, stride):
29 | super(Upsample, self).__init__()
30 |
31 | self.upsample = nn.ConvTranspose2d(
32 | input_dim, output_dim, kernel_size=kernel, stride=stride
33 | )
34 |
35 | def forward(self, x):
36 | return self.upsample(x)
37 |
38 |
39 |
40 | class ResUnet(nn.Module):
41 | def __init__(self, channel=3, filters=[16, 32, 64, 128]):
42 | super(ResUnet, self).__init__()
43 |
44 | self.input_layer = nn.Sequential(
45 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
46 | nn.BatchNorm2d(filters[0]),
47 | nn.ReLU(),
48 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
49 | )
50 | self.input_skip = nn.Sequential(
51 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
52 | )
53 |
54 | self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1)
55 | self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)
56 |
57 | self.bridge = ResidualConv(filters[2], filters[3], 2, 1)
58 |
59 | self.upsample_1 = Upsample(filters[3], filters[3], 2, 2)
60 | # self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)
61 |
62 | # self.upsample_2 = Upsample(filters[2], filters[2], 2, 2)
63 | # self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)
64 |
65 | # self.upsample_3 = Upsample(filters[1], filters[1], 2, 2)
66 | # self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1)
67 |
68 | self.output_layer = nn.Sequential(
69 | nn.Conv2d(filters[2]+filters[3], 32, 1, 1)
70 | )
71 |
72 | def forward(self, x):
73 | # Encode
74 | B, S, C, H, W = x.shape
75 | x = x.view(B*S, C, H, W)
76 | x1 = self.input_layer(x) + self.input_skip(x)
77 | x2 = self.residual_conv_1(x1)
78 | x3 = self.residual_conv_2(x2)
79 | # Bridge
80 | x4 = self.bridge(x3)
81 | # Decode
82 | x4 = self.upsample_1(x4)
83 | x5 = torch.cat([x4, x3], dim=1)
84 |
85 | # x6 = self.up_residual_conv1(x5)
86 |
87 | # x6 = self.upsample_2(x6)
88 | # x7 = torch.cat([x6, x2], dim=1)
89 |
90 | # x8 = self.up_residual_conv2(x7)
91 |
92 | # x8 = self.upsample_3(x8)
93 | # x9 = torch.cat([x8, x1], dim=1)
94 |
95 | # x10 = self.up_residual_conv3(x9)
96 |
97 | output = self.output_layer(x5)
98 | output = output.view(B, S, 32, H//4, W//4)
99 | return output
100 |
--------------------------------------------------------------------------------
/lib/networks/make_network.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imp
3 |
4 |
5 | def make_network(cfg):
6 | module = cfg.network_module
7 | path = cfg.network_path
8 | network = imp.load_source(module, path).Network()
9 | return network
10 |
--------------------------------------------------------------------------------
/lib/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainers import make_trainer
2 | from .optimizer import make_optimizer
3 | from .scheduler import make_lr_scheduler, set_lr_scheduler
4 | from .recorder import make_recorder
5 |
6 |
--------------------------------------------------------------------------------
/lib/train/losses/gefu.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from lib.config import cfg
5 | from lib.train.losses.vgg_perceptual_loss import VGGPerceptualLoss
6 | from lib.train.losses.ssim_loss import SSIM
7 |
8 |
9 | class NetworkWrapper(nn.Module):
10 | def __init__(self, net, train_loader):
11 | super(NetworkWrapper, self).__init__()
12 | self.device = torch.device('cuda:{}'.format(cfg.local_rank))
13 | self.net = net
14 | self.color_crit = nn.MSELoss(reduction='mean')
15 | self.mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
16 | self.perceptual_loss = VGGPerceptualLoss().to(self.device)
17 |
18 | def forward(self, batch):
19 | output = self.net(batch)
20 |
21 | scalar_stats = {}
22 | loss = 0
23 | for i in range(cfg.gefu.cas_config.num):
24 | color_loss = self.color_crit(batch[f'rgb_{i}'], output[f'rgb_level{i}'])
25 | scalar_stats.update({f'color_mse_{i}': color_loss})
26 | loss += cfg.gefu.cas_config.loss_weight[i] * color_loss
27 |
28 | psnr = -10. * torch.log(color_loss) / torch.log(torch.Tensor([10.]).to(color_loss.device))
29 | scalar_stats.update({f'psnr_{i}': psnr})
30 |
31 | num_patchs = cfg.gefu.cas_config.num_patchs[i]
32 | if cfg.gefu.cas_config.train_img[i]:
33 | render_scale = cfg.gefu.cas_config.render_scale[i]
34 | B, S, C, H, W = batch['src_inps'].shape
35 | H, W = int(H * render_scale), int(W * render_scale)
36 | inp = output[f'rgb_level{i}'].reshape(B, H, W, 3).permute(0, 3, 1, 2)
37 | tar = batch[f'rgb_{i}'].reshape(B, H, W, 3).permute(0, 3, 1, 2)
38 | perceptual_loss = self.perceptual_loss(inp, tar)
39 | loss += 0.1 * perceptual_loss * cfg.gefu.cas_config.loss_weight[i]
40 | scalar_stats.update({f'perceptual_loss_{i}': perceptual_loss.detach()})
41 | ssim = SSIM(window_size = 7)
42 | ssim_loss = 1-ssim(inp, tar)
43 | loss += 0.1 * ssim_loss * cfg.gefu.cas_config.loss_weight[i]
44 | scalar_stats.update({f'ssim_loss_{i}': ssim_loss.detach()})
45 |
46 | elif num_patchs > 0:
47 | patch_size = cfg.gefu.cas_config.patch_size[i]
48 | num_rays = cfg.gefu.cas_config.num_rays[i]
49 | patch_rays = int(patch_size ** 2)
50 | inp = torch.empty((0, 3, patch_size, patch_size)).to(self.device)
51 | tar = torch.empty((0, 3, patch_size, patch_size)).to(self.device)
52 | for j in range(num_patchs):
53 | inp = torch.cat([inp, output[f'rgb_level{i}'][:, num_rays+j*patch_rays:num_rays+(j+1)*patch_rays, :].reshape(-1, patch_size, patch_size, 3).permute(0, 3, 1, 2)])
54 | tar = torch.cat([tar, batch[f'rgb_{i}'][:, num_rays+j*patch_rays:num_rays+(j+1)*patch_rays, :].reshape(-1, patch_size, patch_size, 3).permute(0, 3, 1, 2)])
55 | perceptual_loss = self.perceptual_loss(inp, tar)
56 |
57 | loss += 0.01 * perceptual_loss * cfg.gefu.cas_config.loss_weight[i]
58 | scalar_stats.update({f'perceptual_loss_{i}': perceptual_loss.detach()})
59 |
60 | scalar_stats.update({'loss': loss})
61 | image_stats = {}
62 |
63 | return output, loss, scalar_stats, image_stats
64 |
65 |
--------------------------------------------------------------------------------
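Note: the inline PSNR above is the standard `-10 * log10(MSE)` conversion (the unused `self.mse2psnr` lambda encodes the same rule), and `SSIM(window_size=7)` is re-instantiated on every forward pass; hoisting it into `__init__` would avoid rebuilding the Gaussian window each step. A one-line sanity check of the conversion:

```python
import torch

mse = torch.tensor(1e-3)
psnr = -10. * torch.log(mse) / torch.log(torch.tensor(10.))
print(psnr.item())  # ~30 dB for an average squared error of 1e-3 on [0, 1] RGB
```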
/lib/train/losses/nerf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from lib.utils import net_utils
4 | from lib.config import cfg
5 |
6 | class NetworkWrapper(nn.Module):
7 |     def __init__(self, net, train_loader=None):  # train_loader accepted for make_trainer compatibility
8 | super(NetworkWrapper, self).__init__()
9 | self.net = net
10 | self.color_crit = nn.MSELoss(reduction='mean')
11 | self.mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
12 |
13 | def forward(self, batch):
14 | output = self.net(batch)
15 | scalar_stats = {}
16 | loss = 0
17 | color_loss = self.color_crit(output['rgb_0'], batch['rgb'])
18 | scalar_stats.update({'color_mse_0': color_loss})
19 | loss += color_loss
20 |
21 | psnr = -10. * torch.log(color_loss.detach()) / \
22 | torch.log(torch.Tensor([10.]).to(color_loss.device))
23 | scalar_stats.update({'psnr_0': psnr})
24 |
25 | if len(cfg.task_arg.cascade_samples) > 1:
26 | color_loss = self.color_crit(output['rgb_1'], batch['rgb'])
27 | scalar_stats.update({'color_mse_1': color_loss})
28 | loss += color_loss
29 |
30 | psnr = -10. * torch.log(color_loss.detach()) / \
31 | torch.log(torch.Tensor([10.]).to(color_loss.device))
32 | scalar_stats.update({'psnr_1': psnr})
33 |
34 | scalar_stats.update({'loss': loss})
35 | image_stats = {}
36 |
37 | return output, loss, scalar_stats, image_stats
38 |
--------------------------------------------------------------------------------
/lib/train/losses/ssim_loss.py:
--------------------------------------------------------------------------------
1 | ##### the code is borrowed from https://github.com/Po-Hsun-Su/pytorch-ssim #####
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import numpy as np
6 | from math import exp
7 |
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10 | return gauss/gauss.sum()
11 |
12 | def create_window(window_size, channel):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
16 | return window
17 |
18 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
21 |
22 | mu1_sq = mu1.pow(2)
23 | mu2_sq = mu2.pow(2)
24 | mu1_mu2 = mu1*mu2
25 |
26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
29 |
30 | C1 = 0.01**2
31 | C2 = 0.03**2
32 |
33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
34 |
35 | if size_average:
36 | return ssim_map.mean()
37 | else:
38 | return ssim_map.mean(1).mean(1).mean(1)
39 |
40 | class SSIM(torch.nn.Module):
41 | def __init__(self, window_size = 11, size_average = True):
42 | super(SSIM, self).__init__()
43 | self.window_size = window_size
44 | self.size_average = size_average
45 | self.channel = 1
46 | self.window = create_window(window_size, self.channel)
47 |
48 | def forward(self, img1, img2):
49 | (_, channel, _, _) = img1.size()
50 |
51 | if channel == self.channel and self.window.data.type() == img1.data.type():
52 | window = self.window
53 | else:
54 | window = create_window(self.window_size, channel)
55 |
56 | if img1.is_cuda:
57 | window = window.cuda(img1.get_device())
58 | window = window.type_as(img1)
59 |
60 | self.window = window
61 | self.channel = channel
62 |
63 |
64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def ssim(img1, img2, window_size = 11, size_average = True):
67 | (_, channel, _, _) = img1.size()
68 | window = create_window(window_size, channel)
69 |
70 | if img1.is_cuda:
71 | window = window.cuda(img1.get_device())
72 | window = window.type_as(img1)
73 |
74 | return _ssim(img1, img2, window, window_size, channel, size_average)
75 |
--------------------------------------------------------------------------------
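A minimal usage sketch of the SSIM module above (window size 7 matches the training loss; shapes are arbitrary):

```python
import torch

ssim = SSIM(window_size=7)
img = torch.rand(1, 3, 64, 64)
print(ssim(img, img.clone()).item())  # 1.0 for identical images
# losses/gefu.py uses it as: ssim_loss = 1 - ssim(inp, tar)
```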
/lib/train/losses/vgg_perceptual_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | class VGGPerceptualLoss(torch.nn.Module):
5 | def __init__(self, resize=False):
6 | super(VGGPerceptualLoss, self).__init__()
7 |         vgg = torchvision.models.vgg16(pretrained=True).features  # load VGG16 once and slice it
8 |         blocks = [vgg[:4].eval(),     # relu1_2
9 |                   vgg[4:9].eval(),    # relu2_2
10 |                   vgg[9:16].eval(),   # relu3_3
11 |                   vgg[16:23].eval()]  # relu4_3
12 | for bl in blocks:
13 | for p in bl.parameters():
14 | p.requires_grad = False
15 | self.blocks = torch.nn.ModuleList(blocks)
16 | self.transform = torch.nn.functional.interpolate
17 | self.resize = resize
18 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
19 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
20 |
21 | def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
22 | if input.shape[1] != 3:
23 | input = input.repeat(1, 3, 1, 1)
24 | target = target.repeat(1, 3, 1, 1)
25 | input = (input-self.mean) / self.std
26 | target = (target-self.mean) / self.std
27 | if self.resize:
28 | input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
29 | target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
30 | loss = 0.0
31 | x = input
32 | y = target
33 | for i, block in enumerate(self.blocks):
34 | x = block(x)
35 | y = block(y)
36 | if i in feature_layers:
37 | loss += torch.nn.functional.l1_loss(x, y)
38 | if i in style_layers:
39 | act_x = x.reshape(x.shape[0], x.shape[1], -1)
40 | act_y = y.reshape(y.shape[0], y.shape[1], -1)
41 | gram_x = act_x @ act_x.permute(0, 2, 1)
42 | gram_y = act_y @ act_y.permute(0, 2, 1)
43 | loss += torch.nn.functional.l1_loss(gram_x, gram_y)
44 | return loss
45 |
--------------------------------------------------------------------------------
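A minimal usage sketch (the VGG16 weights are downloaded on first use; shapes are arbitrary):

```python
import torch

loss_fn = VGGPerceptualLoss()      # resize=False: features are computed at input resolution
pred = torch.rand(2, 3, 128, 128)  # e.g. rendered RGB in [0, 1]
gt = torch.rand(2, 3, 128, 128)
print(loss_fn(pred, gt).item())    # sum of L1 feature distances over the four VGG stages
```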
/lib/train/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lib.utils.optimizer.radam import RAdam
3 |
4 |
5 | _optimizer_factory = {
6 | 'adam': torch.optim.Adam,
7 | 'radam': RAdam,
8 | 'sgd': torch.optim.SGD
9 | }
10 |
11 |
12 | def make_optimizer(cfg, net):
13 | params = []
14 | lr = cfg.train.lr
15 | weight_decay = cfg.train.weight_decay
16 | eps = cfg.train.eps
17 |
18 | for key, value in net.named_parameters():
19 | if not value.requires_grad:
20 | continue
21 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay, "eps": eps}]
22 |
23 | if 'adam' in cfg.train.optim:
24 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, weight_decay=weight_decay, eps=eps)
25 | else:
26 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, momentum=0.9)
27 |
28 | return optimizer
29 |
--------------------------------------------------------------------------------
/lib/train/recorder.py:
--------------------------------------------------------------------------------
1 | from collections import deque, defaultdict
2 | import torch
3 | from tensorboardX import SummaryWriter
4 | import os
5 | from lib.config.config import cfg
6 |
7 | from termcolor import colored
8 |
9 |
10 | class SmoothedValue(object):
11 | """Track a series of values and provide access to smoothed values over a
12 | window or the global series average.
13 | """
14 |
15 | def __init__(self, window_size=20):
16 | self.deque = deque(maxlen=window_size)
17 | self.total = 0.0
18 | self.count = 0
19 |
20 | def update(self, value):
21 | self.deque.append(value)
22 | self.count += 1
23 | self.total += value
24 |
25 | @property
26 | def median(self):
27 | d = torch.tensor(list(self.deque))
28 | return d.median().item()
29 |
30 | @property
31 | def avg(self):
32 | d = torch.tensor(list(self.deque))
33 | return d.mean().item()
34 |
35 | @property
36 | def global_avg(self):
37 | return self.total / self.count
38 |
39 |
40 | def process_volsdf(image_stats):
41 | for k, v in image_stats.items():
42 | image_stats[k] = torch.clamp(v[0].permute(2, 0, 1), min=0., max=1.)
43 | return image_stats
44 |
45 | process_neus = process_volsdf
46 |
47 | class Recorder(object):
48 | def __init__(self, cfg):
49 | if cfg.local_rank > 0:
50 | return
51 |
52 | log_dir = cfg.record_dir
53 | if not cfg.resume:
54 | print(colored('remove contents of directory %s' % log_dir, 'red'))
55 | os.system('rm -r %s/*' % log_dir)
56 | self.writer = SummaryWriter(log_dir=log_dir)
57 |
58 | # scalars
59 | self.epoch = 0
60 | self.step = 0
61 | self.loss_stats = defaultdict(SmoothedValue)
62 | self.batch_time = SmoothedValue()
63 | self.data_time = SmoothedValue()
64 |
65 | # images
66 | self.image_stats = defaultdict(object)
67 | if 'process_' + cfg.task in globals():
68 | self.processor = globals()['process_' + cfg.task]
69 | else:
70 | self.processor = None
71 |
72 | def update_loss_stats(self, loss_dict):
73 | if cfg.local_rank > 0:
74 | return
75 | for k, v in loss_dict.items():
76 | self.loss_stats[k].update(v.detach().cpu())
77 |
78 | def update_image_stats(self, image_stats):
79 | if cfg.local_rank > 0:
80 | return
81 | if self.processor is None:
82 | return
83 | image_stats = self.processor(image_stats)
84 | for k, v in image_stats.items():
85 | self.image_stats[k] = v.detach().cpu()
86 |
87 | def record(self, prefix, step=-1, loss_stats=None, image_stats=None):
88 | if cfg.local_rank > 0:
89 | return
90 |
91 | pattern = prefix + '/{}'
92 | step = step if step >= 0 else self.step
93 | loss_stats = loss_stats if loss_stats else self.loss_stats
94 |
95 | for k, v in loss_stats.items():
96 | if isinstance(v, SmoothedValue):
97 | self.writer.add_scalar(pattern.format(k), v.median, step)
98 | else:
99 | self.writer.add_scalar(pattern.format(k), v, step)
100 |
101 | if self.processor is None:
102 | return
103 | image_stats = self.processor(image_stats) if image_stats else self.image_stats
104 | for k, v in image_stats.items():
105 | self.writer.add_image(pattern.format(k), v, step)
106 |
107 | def state_dict(self):
108 | if cfg.local_rank > 0:
109 | return
110 | scalar_dict = {}
111 | scalar_dict['step'] = self.step
112 | return scalar_dict
113 |
114 | def load_state_dict(self, scalar_dict):
115 | if cfg.local_rank > 0:
116 | return
117 | self.step = scalar_dict['step']
118 |
119 | def __str__(self):
120 | if cfg.local_rank > 0:
121 |             return ''  # __str__ must return a str, not None
122 | loss_state = []
123 | for k, v in self.loss_stats.items():
124 | loss_state.append('{}: {:.4f}'.format(k, v.avg))
125 | loss_state = ' '.join(loss_state)
126 |
127 | recording_state = ' '.join(['epoch: {}', 'step: {}', '{}', 'data: {:.4f}', 'batch: {:.4f}'])
128 | return recording_state.format(self.epoch, self.step, loss_state, self.data_time.avg, self.batch_time.avg)
129 |
130 |
131 | def make_recorder(cfg):
132 | return Recorder(cfg)
133 |
--------------------------------------------------------------------------------
/lib/train/scheduler.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 | from lib.utils.optimizer.lr_scheduler import WarmupMultiStepLR, MultiStepLR, ExponentialLR
3 |
4 |
5 | def make_lr_scheduler(cfg, optimizer):
6 | cfg_scheduler = cfg.train.scheduler
7 | if cfg_scheduler.type == 'multi_step':
8 | scheduler = MultiStepLR(optimizer,
9 | milestones=cfg_scheduler.milestones,
10 | gamma=cfg_scheduler.gamma)
11 | elif cfg_scheduler.type == 'exponential':
12 | scheduler = ExponentialLR(optimizer,
13 | decay_epochs=cfg_scheduler.decay_epochs,
14 | gamma=cfg_scheduler.gamma)
15 | return scheduler
16 |
17 |
18 | def set_lr_scheduler(cfg, scheduler):
19 | cfg_scheduler = cfg.train.scheduler
20 | if cfg_scheduler.type == 'multi_step':
21 | scheduler.milestones = Counter(cfg_scheduler.milestones)
22 | elif cfg_scheduler.type == 'exponential':
23 | scheduler.decay_epochs = cfg_scheduler.decay_epochs
24 | scheduler.gamma = cfg_scheduler.gamma
25 |
--------------------------------------------------------------------------------
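Both branches map onto the classes in lib/utils/optimizer/lr_scheduler.py (included later in this listing). The exponential rule decays continuously rather than in discrete steps; a quick numeric sketch with made-up values:

```python
# lr(epoch) = base_lr * gamma ** (epoch / decay_epochs), as in ExponentialLR.get_lr()
base_lr, gamma, decay_epochs = 5e-4, 0.5, 50
for epoch in (0, 25, 50, 100):
    print(epoch, base_lr * gamma ** (epoch / decay_epochs))
# the learning rate halves every `decay_epochs` epochs
```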
/lib/train/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .make_trainer import make_trainer
2 |
--------------------------------------------------------------------------------
/lib/train/trainers/make_trainer.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
2 | import imp
3 |
4 |
5 | def _wrapper_factory(cfg, network, train_loader=None):
6 | module = cfg.loss_module
7 | path = cfg.loss_path
8 | network_wrapper = imp.load_source(module, path).NetworkWrapper(network, train_loader)
9 | return network_wrapper
10 |
11 |
12 | def make_trainer(cfg, network, train_loader=None):
13 | network = _wrapper_factory(cfg, network, train_loader)
14 | return Trainer(network)
15 |
--------------------------------------------------------------------------------
/lib/train/trainers/trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import datetime
3 | import torch
4 | import tqdm
5 | from torch.nn import DataParallel
6 | from torch.nn.parallel import DistributedDataParallel
7 | from lib.config import cfg
8 | from lib.utils.data_utils import to_cuda
9 |
10 |
11 | class Trainer(object):
12 | def __init__(self, network):
13 | device = torch.device('cuda:{}'.format(cfg.local_rank))
14 | network = network.to(device)
15 | if cfg.distributed:
16 | network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(network)
17 | network = DistributedDataParallel(
18 | network,
19 | device_ids=[cfg.local_rank],
20 | output_device=cfg.local_rank,
21 | find_unused_parameters=True
22 | )
23 | self.network = network
24 | self.local_rank = cfg.local_rank
25 | self.device = device
26 | self.global_step = 0
27 |
28 | def reduce_loss_stats(self, loss_stats):
29 | reduced_losses = {k: torch.mean(v) for k, v in loss_stats.items()}
30 | return reduced_losses
31 |
32 | def to_cuda(self, batch):
33 | for k in batch:
34 | if isinstance(batch[k], tuple) or isinstance(batch[k], list):
35 | #batch[k] = [b.cuda() for b in batch[k]]
36 | batch[k] = [b.to(self.device) for b in batch[k]]
37 | elif isinstance(batch[k], dict):
38 | batch[k] = {key: self.to_cuda(batch[k][key]) for key in batch[k]}
39 | else:
40 | # batch[k] = batch[k].cuda()
41 | batch[k] = batch[k].to(self.device)
42 | return batch
43 |
44 | def train(self, epoch, data_loader, optimizer, recorder):
45 | max_iter = len(data_loader)
46 | self.network.train()
47 | end = time.time()
48 | if self.global_step == 0:
49 | self.global_step = cfg.ep_iter * epoch
50 | for iteration, batch in enumerate(data_loader):
51 | data_time = time.time() - end
52 | iteration = iteration + 1
53 |
54 | batch = to_cuda(batch, self.device)
55 | batch['step'] = 0
56 | output, loss, loss_stats, image_stats = self.network(batch)
57 | # training stage: loss; optimizer; scheduler
58 | loss = loss.mean()
59 | optimizer.zero_grad()
60 | loss.backward()
61 | torch.nn.utils.clip_grad_value_(self.network.parameters(), 40)
62 | optimizer.step()
63 |
64 | if cfg.local_rank > 0:
65 | continue
66 |
67 | # data recording stage: loss_stats, time, image_stats
68 | recorder.step += 1
69 |
70 | loss_stats = self.reduce_loss_stats(loss_stats)
71 | recorder.update_loss_stats(loss_stats)
72 |
73 | batch_time = time.time() - end
74 | end = time.time()
75 | recorder.batch_time.update(batch_time)
76 | recorder.data_time.update(data_time)
77 |
78 | self.global_step += 1
79 | if iteration % cfg.log_interval == 0 or iteration == (max_iter - 1):
80 | # print training state
81 | eta_seconds = recorder.batch_time.global_avg * (max_iter - iteration)
82 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
83 | lr = optimizer.param_groups[0]['lr']
84 | memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
85 |
86 | training_state = ' '.join(['eta: {}', '{}', 'lr: {:.6f}', 'max_mem: {:.0f}'])
87 | training_state = training_state.format(eta_string, str(recorder), lr, memory)
88 | print(training_state)
89 |
90 | # record loss_stats and image_dict
91 | recorder.update_image_stats(image_stats)
92 | recorder.record('train')
93 |
94 | def val(self, epoch, data_loader, evaluator=None, recorder=None):
95 | self.network.eval()
96 | torch.cuda.empty_cache()
97 | val_loss_stats = {}
98 | image_stats = {}
99 | data_size = len(data_loader)
100 | for batch in tqdm.tqdm(data_loader):
101 | batch = to_cuda(batch, self.device)
102 | batch['step'] = recorder.step
103 | with torch.no_grad():
104 | output, loss, loss_stats, _ = self.network(batch)
105 | if evaluator is not None:
106 | image_stats_ = evaluator.evaluate(output, batch)
107 | if image_stats_ is not None:
108 | image_stats.update(image_stats_)
109 |
110 | loss_stats = self.reduce_loss_stats(loss_stats)
111 | for k, v in loss_stats.items():
112 | val_loss_stats.setdefault(k, 0)
113 | val_loss_stats[k] += v
114 |
115 | loss_state = []
116 | for k in val_loss_stats.keys():
117 | val_loss_stats[k] /= data_size
118 | loss_state.append('{}: {:.4f}'.format(k, val_loss_stats[k]))
119 | print(loss_state)
120 |
121 |         result = None  # avoid an unbound name when no evaluator is given
122 |         if evaluator is not None:
123 |             result = evaluator.summarize()
124 |             val_loss_stats.update(result)
125 | 
126 |         if recorder:
127 |             recorder.record('val', epoch, val_loss_stats, image_stats)
128 |         return result
--------------------------------------------------------------------------------
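A rough sketch of how these factories compose into a training loop. The real entry point is train_net.py, so treat the import paths, `cfg` fields, and loop structure below as assumptions:

```python
from lib.config import cfg
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder

network = make_network(cfg)
trainer = make_trainer(cfg, network)           # wraps the net with its loss module
optimizer = make_optimizer(cfg, network)
scheduler = make_lr_scheduler(cfg, optimizer)
recorder = make_recorder(cfg)

for epoch in range(cfg.train.epoch):           # train_loader: a DataLoader built in lib/datasets
    trainer.train(epoch, train_loader, optimizer, recorder)
    scheduler.step()
```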
/lib/utils/base_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import numpy as np
4 | import cv2
5 | import time
6 | from termcolor import colored
7 | import importlib
8 | import torch  # perf_timer below calls torch.cuda.synchronize()
9 | import torch.distributed as dist
10 | import math
10 |
11 | class perf_timer:
12 | def __init__(self, msg="Elapsed time: {}s", logf=lambda x: print(colored(x, 'yellow')), sync_cuda=True, use_ms=False, disabled=False):
13 | self.logf = logf
14 | self.msg = msg
15 | self.sync_cuda = sync_cuda
16 | self.use_ms = use_ms
17 | self.disabled = disabled
18 |
19 | self.loggedtime = None
20 |
21 | def __enter__(self,):
22 | if self.sync_cuda:
23 | torch.cuda.synchronize()
24 | self.start = time.perf_counter()
25 |
26 | def __exit__(self, exc_type, exc_value, traceback):
27 | if self.sync_cuda:
28 | torch.cuda.synchronize()
29 | self.logtime(self.msg)
30 |
31 | def logtime(self, msg=None, logf=None):
32 | if self.disabled:
33 | return
34 | # SAME CLASS, DIFFERENT FUNCTIONALITY, is this good?
35 | # call the logger for timing code sections
36 | if self.sync_cuda:
37 | torch.cuda.synchronize()
38 |
39 | # always remember current time
40 | prev = self.loggedtime
41 | self.loggedtime = time.perf_counter()
42 |
43 | # print it if we've remembered previous time
44 | if prev is not None and msg:
45 | logf = logf or self.logf
46 | diff = self.loggedtime-prev
47 | diff *= 1000 if self.use_ms else 1
48 | logf(msg.format(diff))
49 |
50 | return self.loggedtime
51 |
52 | def read_pickle(pkl_path):
53 | with open(pkl_path, 'rb') as f:
54 | return pickle.load(f)
55 |
56 |
57 | def save_pickle(data, pkl_path):
58 | os.system('mkdir -p {}'.format(os.path.dirname(pkl_path)))
59 | with open(pkl_path, 'wb') as f:
60 | pickle.dump(data, f)
61 |
62 |
63 | def project(xyz, K, RT):
64 | """
65 | xyz: [N, 3]
66 | K: [3, 3]
67 | RT: [3, 4]
68 | """
69 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T
70 | xyz = np.dot(xyz, K.T)
71 | xy = xyz[:, :2] / xyz[:, 2:]
72 | return xy
73 |
74 | def get_bbox_2d(bbox, K, RT):
75 | pts = np.array([[bbox[0, 0], bbox[0, 1], bbox[0, 2]],
76 | [bbox[0, 0], bbox[0, 1], bbox[1, 2]],
77 | [bbox[0, 0], bbox[1, 1], bbox[0, 2]],
78 | [bbox[0, 0], bbox[1, 1], bbox[1, 2]],
79 | [bbox[1, 0], bbox[0, 1], bbox[0, 2]],
80 | [bbox[1, 0], bbox[0, 1], bbox[1, 2]],
81 | [bbox[1, 0], bbox[1, 1], bbox[0, 2]],
82 | [bbox[1, 0], bbox[1, 1], bbox[1, 2]],
83 | ])
84 | pts_2d = project(pts, K, RT)
85 | return [pts_2d[:, 0].min(), pts_2d[:, 1].min(), pts_2d[:, 0].max(), pts_2d[:, 1].max()]
86 |
87 |
88 | def get_bound_corners(bounds):
89 | min_x, min_y, min_z = bounds[0]
90 | max_x, max_y, max_z = bounds[1]
91 | corners_3d = np.array([
92 | [min_x, min_y, min_z],
93 | [min_x, min_y, max_z],
94 | [min_x, max_y, min_z],
95 | [min_x, max_y, max_z],
96 | [max_x, min_y, min_z],
97 | [max_x, min_y, max_z],
98 | [max_x, max_y, min_z],
99 | [max_x, max_y, max_z],
100 | ])
101 | return corners_3d
102 |
103 | def get_bound_2d_mask(bounds, K, pose, H, W):
104 | corners_3d = get_bound_corners(bounds)
105 | corners_2d = project(corners_3d, K, pose)
106 | corners_2d = np.round(corners_2d).astype(int)
107 | mask = np.zeros((H, W), dtype=np.uint8)
108 | cv2.fillPoly(mask, [corners_2d[[0, 1, 3, 2, 0]]], 1)
109 | cv2.fillPoly(mask, [corners_2d[[4, 5, 7, 6, 5]]], 1)
110 | cv2.fillPoly(mask, [corners_2d[[0, 1, 5, 4, 0]]], 1)
111 | cv2.fillPoly(mask, [corners_2d[[2, 3, 7, 6, 2]]], 1)
112 | cv2.fillPoly(mask, [corners_2d[[0, 2, 6, 4, 0]]], 1)
113 | cv2.fillPoly(mask, [corners_2d[[1, 3, 7, 5, 1]]], 1)
114 | return mask
115 |
116 | def load_object(module_name, module_args, **extra_args):
117 | module_path = '.'.join(module_name.split('.')[:-1])
118 | module = importlib.import_module(module_path)
119 | name = module_name.split('.')[-1]
120 | obj = getattr(module, name)(**extra_args, **module_args)
121 | return obj
122 |
123 |
124 |
125 | def get_indices(length):
126 | num_replicas = dist.get_world_size()
127 | rank = dist.get_rank()
128 | num_samples = int(math.ceil(length * 1.0 / num_replicas))
129 | total_size = num_samples * num_replicas
130 | indices = np.arange(length).tolist()
131 | indices += indices[: (total_size - len(indices))]
132 | offset = num_samples * rank
133 | indices = indices[offset:offset+num_samples]
134 | return indices
135 |
136 |
137 | class DotDict(dict):
138 | """
139 | a dictionary that supports dot notation
140 | as well as dictionary access notation
141 | usage: d = DotDict() or d = DotDict({'val1':'first'})
142 | set attributes: d.val2 = 'second' or d['val2'] = 'second'
143 | get attributes: d.val2 or d['val2']
144 | """
145 | __getattr__ = dict.__getitem__
146 | __setattr__ = dict.__setitem__
147 | __delattr__ = dict.__delitem__
148 |
149 | def __init__(self, dct=None):
150 | if dct is not None:
151 | for key, value in dct.items():
152 | if hasattr(value, 'keys'):
153 | value = DotDict(value)
154 |                 self[key] = value
155 |
156 |
157 |
158 |
--------------------------------------------------------------------------------
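A quick sketch of the pinhole projection in `project()` with a made-up intrinsic matrix and an identity camera pose:

```python
import numpy as np

K = np.array([[500., 0., 320.],
              [0., 500., 240.],
              [0., 0., 1.]])
RT = np.hstack([np.eye(3), np.zeros((3, 1))])  # world frame == camera frame
xyz = np.array([[0., 0., 2.]])                 # a point 2 units in front of the camera
print(project(xyz, K, RT))                     # [[320. 240.]], the principal point
```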
/lib/utils/data_config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | mean_rgb = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3).astype(np.float32)
3 | std_rgb = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3).astype(np.float32)
4 |
--------------------------------------------------------------------------------
/lib/utils/img_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from matplotlib import cm
3 | import matplotlib.pyplot as plt
4 | import matplotlib.patches as patches
5 | import numpy as np
6 | import cv2
7 | from lib.utils import data_config
8 |
9 |
10 |
11 | def unnormalize_img(img, mean, std):
12 | """
13 | img: [3, h, w]
14 | """
15 | img = img.detach().cpu().clone()
16 | # img = img / 255.
17 | img *= torch.tensor(std).view(3, 1, 1)
18 | img += torch.tensor(mean).view(3, 1, 1)
19 | min_v = torch.min(img)
20 | img = (img - min_v) / (torch.max(img) - min_v)
21 | return img
22 |
23 |
24 | def bgr_to_rgb(img):
25 | return img[:, :, [2, 1, 0]]
26 |
27 |
28 | def horizon_concate(inp0, inp1):
29 | h0, w0 = inp0.shape[:2]
30 | h1, w1 = inp1.shape[:2]
31 | if inp0.ndim == 3:
32 | inp = np.zeros((max(h0, h1), w0 + w1, 3), dtype=inp0.dtype)
33 | inp[:h0, :w0, :] = inp0
34 | inp[:h1, w0:(w0 + w1), :] = inp1
35 | else:
36 | inp = np.zeros((max(h0, h1), w0 + w1), dtype=inp0.dtype)
37 | inp[:h0, :w0] = inp0
38 | inp[:h1, w0:(w0 + w1)] = inp1
39 | return inp
40 |
41 |
42 | def vertical_concate(inp0, inp1):
43 | h0, w0 = inp0.shape[:2]
44 | h1, w1 = inp1.shape[:2]
45 | if inp0.ndim == 3:
46 | inp = np.zeros((h0 + h1, max(w0, w1), 3), dtype=inp0.dtype)
47 | inp[:h0, :w0, :] = inp0
48 | inp[h0:(h0 + h1), :w1, :] = inp1
49 | else:
50 | inp = np.zeros((h0 + h1, max(w0, w1)), dtype=inp0.dtype)
51 | inp[:h0, :w0] = inp0
52 | inp[h0:(h0 + h1), :w1] = inp1
53 | return inp
54 |
55 |
56 | def transparent_cmap(cmap):
57 | """Copy colormap and set alpha values"""
58 | mycmap = cmap
59 | mycmap._init()
60 | mycmap._lut[:,-1] = 0.3
61 | return mycmap
62 |
63 | cmap = transparent_cmap(plt.get_cmap('jet'))
64 |
65 |
66 | def set_grid(ax, h, w, interval=8):
67 | ax.set_xticks(np.arange(0, w, interval))
68 | ax.set_yticks(np.arange(0, h, interval))
69 | ax.grid()
70 | ax.set_yticklabels([])
71 | ax.set_xticklabels([])
72 |
73 |
74 | color_list = np.array(
75 | [
76 | 0.000, 0.447, 0.741,
77 | 0.850, 0.325, 0.098,
78 | 0.929, 0.694, 0.125,
79 | 0.494, 0.184, 0.556,
80 | 0.466, 0.674, 0.188,
81 | 0.301, 0.745, 0.933,
82 | 0.635, 0.078, 0.184,
83 | 0.300, 0.300, 0.300,
84 | 0.600, 0.600, 0.600,
85 | 1.000, 0.000, 0.000,
86 | 1.000, 0.500, 0.000,
87 | 0.749, 0.749, 0.000,
88 | 0.000, 1.000, 0.000,
89 | 0.000, 0.000, 1.000,
90 | 0.667, 0.000, 1.000,
91 | 0.333, 0.333, 0.000,
92 | 0.333, 0.667, 0.000,
93 | 0.333, 1.000, 0.000,
94 | 0.667, 0.333, 0.000,
95 | 0.667, 0.667, 0.000,
96 | 0.667, 1.000, 0.000,
97 | 1.000, 0.333, 0.000,
98 | 1.000, 0.667, 0.000,
99 | 1.000, 1.000, 0.000,
100 | 0.000, 0.333, 0.500,
101 | 0.000, 0.667, 0.500,
102 | 0.000, 1.000, 0.500,
103 | 0.333, 0.000, 0.500,
104 | 0.333, 0.333, 0.500,
105 | 0.333, 0.667, 0.500,
106 | 0.333, 1.000, 0.500,
107 | 0.667, 0.000, 0.500,
108 | 0.667, 0.333, 0.500,
109 | 0.667, 0.667, 0.500,
110 | 0.667, 1.000, 0.500,
111 | 1.000, 0.000, 0.500,
112 | 1.000, 0.333, 0.500,
113 | 1.000, 0.667, 0.500,
114 | 1.000, 1.000, 0.500,
115 | 0.000, 0.333, 1.000,
116 | 0.000, 0.667, 1.000,
117 | 0.000, 1.000, 1.000,
118 | 0.333, 0.000, 1.000,
119 | 0.333, 0.333, 1.000,
120 | 0.333, 0.667, 1.000,
121 | 0.333, 1.000, 1.000,
122 | 0.667, 0.000, 1.000,
123 | 0.667, 0.333, 1.000,
124 | 0.667, 0.667, 1.000,
125 | 0.667, 1.000, 1.000,
126 | 1.000, 0.000, 1.000,
127 | 1.000, 0.333, 1.000,
128 | 1.000, 0.667, 1.000,
129 | 0.167, 0.000, 0.000,
130 | 0.333, 0.000, 0.000,
131 | 0.500, 0.000, 0.000,
132 | 0.667, 0.000, 0.000,
133 | 0.833, 0.000, 0.000,
134 | 1.000, 0.000, 0.000,
135 | 0.000, 0.167, 0.000,
136 | 0.000, 0.333, 0.000,
137 | 0.000, 0.500, 0.000,
138 | 0.000, 0.667, 0.000,
139 | 0.000, 0.833, 0.000,
140 | 0.000, 1.000, 0.000,
141 | 0.000, 0.000, 0.167,
142 | 0.000, 0.000, 0.333,
143 | 0.000, 0.000, 0.500,
144 | 0.000, 0.000, 0.667,
145 | 0.000, 0.000, 0.833,
146 | 0.000, 0.000, 1.000,
147 | 0.000, 0.000, 0.000,
148 | 0.143, 0.143, 0.143,
149 | 0.286, 0.286, 0.286,
150 | 0.429, 0.429, 0.429,
151 | 0.571, 0.571, 0.571,
152 | 0.714, 0.714, 0.714,
153 | 0.857, 0.857, 0.857,
154 | 1.000, 1.000, 1.000,
155 | 0.50, 0.5, 0
156 | ]
157 | ).astype(np.float32)
158 | colors = color_list.reshape((-1, 3)) * 255
159 | colors = np.array(colors, dtype=np.uint8).reshape(len(colors), 1, 1, 3)
160 |
161 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
162 | """
163 | depth: (H, W)
164 | """
165 | x = np.nan_to_num(depth) # change nan to 0
166 | if minmax is None:
167 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
168 | ma = np.max(x)
169 | else:
170 | mi,ma = minmax
171 |
172 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
173 | x = (255*x).astype(np.uint8)
174 | x_ = cv2.applyColorMap(x, cmap)
175 | return x_, [mi,ma]
176 |
--------------------------------------------------------------------------------
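A minimal sketch of `visualize_depth_numpy` on a synthetic depth map (all values made up):

```python
import numpy as np

depth = np.linspace(2.0, 6.0, 240 * 320).reshape(240, 320)  # fake (H, W) depths
color, (mi, ma) = visualize_depth_numpy(depth)              # BGR uint8 heatmap via OpenCV
print(color.shape, mi, ma)                                  # (240, 320, 3) 2.0 6.0
```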
/lib/utils/mask_utils.py:
--------------------------------------------------------------------------------
1 | def get_class_ids_from_labels(labels):
2 | ids = []
3 | for l in labels:
4 | ids.append(label_id_mapping_ade20k[l])
5 | return ids
6 |
7 | def get_label_id_mapping(use_human_mask=False):
8 | if use_human_mask:
9 | return label_id_mapping_human
10 | else:
11 | return label_id_mapping_ade20k
12 |
13 | id_label_mapping_human = {
14 | 0: 'non_person',
15 | 1: 'person'
16 | }
17 |
18 | label_id_mapping_human = {
19 | 'non_person': 0,
20 | 'person': 1
21 | }
22 |
23 | id_label_mapping_ade20k = {0: 'wall',
24 | 1: 'building',
25 | 2: 'sky',
26 | 3: 'floor',
27 | 4: 'tree',
28 | 5: 'ceiling',
29 | 6: 'road',
30 | 7: 'bed ',
31 | 8: 'windowpane',
32 | 9: 'grass',
33 | 10: 'cabinet',
34 | 11: 'sidewalk',
35 | 12: 'person',
36 | 13: 'earth',
37 | 14: 'door',
38 | 15: 'table',
39 | 16: 'mountain',
40 | 17: 'plant',
41 | 18: 'curtain',
42 | 19: 'chair',
43 | 20: 'car',
44 | 21: 'water',
45 | 22: 'painting',
46 | 23: 'sofa',
47 | 24: 'shelf',
48 | 25: 'house',
49 | 26: 'sea',
50 | 27: 'mirror',
51 | 28: 'rug',
52 | 29: 'field',
53 | 30: 'armchair',
54 | 31: 'seat',
55 | 32: 'fence',
56 | 33: 'desk',
57 | 34: 'rock',
58 | 35: 'wardrobe',
59 | 36: 'lamp',
60 | 37: 'bathtub',
61 | 38: 'railing',
62 | 39: 'cushion',
63 | 40: 'base',
64 | 41: 'box',
65 | 42: 'column',
66 | 43: 'signboard',
67 | 44: 'chest of drawers',
68 | 45: 'counter',
69 | 46: 'sand',
70 | 47: 'sink',
71 | 48: 'skyscraper',
72 | 49: 'fireplace',
73 | 50: 'refrigerator',
74 | 51: 'grandstand',
75 | 52: 'path',
76 | 53: 'stairs',
77 | 54: 'runway',
78 | 55: 'case',
79 | 56: 'pool table',
80 | 57: 'pillow',
81 | 58: 'screen door',
82 | 59: 'stairway',
83 | 60: 'river',
84 | 61: 'bridge',
85 | 62: 'bookcase',
86 | 63: 'blind',
87 | 64: 'coffee table',
88 | 65: 'toilet',
89 | 66: 'flower',
90 | 67: 'book',
91 | 68: 'hill',
92 | 69: 'bench',
93 | 70: 'countertop',
94 | 71: 'stove',
95 | 72: 'palm',
96 | 73: 'kitchen island',
97 | 74: 'computer',
98 | 75: 'swivel chair',
99 | 76: 'boat',
100 | 77: 'bar',
101 | 78: 'arcade machine',
102 | 79: 'hovel',
103 | 80: 'bus',
104 | 81: 'towel',
105 | 82: 'light',
106 | 83: 'truck',
107 | 84: 'tower',
108 | 85: 'chandelier',
109 | 86: 'awning',
110 | 87: 'streetlight',
111 | 88: 'booth',
112 | 89: 'television receiver',
113 | 90: 'airplane',
114 | 91: 'dirt track',
115 | 92: 'apparel',
116 | 93: 'pole',
117 | 94: 'land',
118 | 95: 'bannister',
119 | 96: 'escalator',
120 | 97: 'ottoman',
121 | 98: 'bottle',
122 | 99: 'buffet',
123 | 100: 'poster',
124 | 101: 'stage',
125 | 102: 'van',
126 | 103: 'ship',
127 | 104: 'fountain',
128 | 105: 'conveyer belt',
129 | 106: 'canopy',
130 | 107: 'washer',
131 | 108: 'plaything',
132 | 109: 'swimming pool',
133 | 110: 'stool',
134 | 111: 'barrel',
135 | 112: 'basket',
136 | 113: 'waterfall',
137 | 114: 'tent',
138 | 115: 'bag',
139 | 116: 'minibike',
140 | 117: 'cradle',
141 | 118: 'oven',
142 | 119: 'ball',
143 | 120: 'food',
144 | 121: 'step',
145 | 122: 'tank',
146 | 123: 'trade name',
147 | 124: 'microwave',
148 | 125: 'pot',
149 | 126: 'animal',
150 | 127: 'bicycle',
151 | 128: 'lake',
152 | 129: 'dishwasher',
153 | 130: 'screen',
154 | 131: 'blanket',
155 | 132: 'sculpture',
156 | 133: 'hood',
157 | 134: 'sconce',
158 | 135: 'vase',
159 | 136: 'traffic light',
160 | 137: 'tray',
161 | 138: 'ashcan',
162 | 139: 'fan',
163 | 140: 'pier',
164 | 141: 'crt screen',
165 | 142: 'plate',
166 | 143: 'monitor',
167 | 144: 'bulletin board',
168 | 145: 'shower',
169 | 146: 'radiator',
170 | 147: 'glass',
171 | 148: 'clock',
172 | 149: 'flag'}
173 |
174 | label_id_mapping_ade20k = {'airplane': 90,
175 | 'animal': 126,
176 | 'apparel': 92,
177 | 'arcade machine': 78,
178 | 'armchair': 30,
179 | 'ashcan': 138,
180 | 'awning': 86,
181 | 'bag': 115,
182 | 'ball': 119,
183 | 'bannister': 95,
184 | 'bar': 77,
185 | 'barrel': 111,
186 | 'base': 40,
187 | 'basket': 112,
188 | 'bathtub': 37,
189 | 'bed ': 7,
190 | 'bench': 69,
191 | 'bicycle': 127,
192 | 'blanket': 131,
193 | 'blind': 63,
194 | 'boat': 76,
195 | 'book': 67,
196 | 'bookcase': 62,
197 | 'booth': 88,
198 | 'bottle': 98,
199 | 'box': 41,
200 | 'bridge': 61,
201 | 'buffet': 99,
202 | 'building': 1,
203 | 'bulletin board': 144,
204 | 'bus': 80,
205 | 'cabinet': 10,
206 | 'canopy': 106,
207 | 'car': 20,
208 | 'case': 55,
209 | 'ceiling': 5,
210 | 'chair': 19,
211 | 'chandelier': 85,
212 | 'chest of drawers': 44,
213 | 'clock': 148,
214 | 'coffee table': 64,
215 | 'column': 42,
216 | 'computer': 74,
217 | 'conveyer belt': 105,
218 | 'counter': 45,
219 | 'countertop': 70,
220 | 'cradle': 117,
221 | 'crt screen': 141,
222 | 'curtain': 18,
223 | 'cushion': 39,
224 | 'desk': 33,
225 | 'dirt track': 91,
226 | 'dishwasher': 129,
227 | 'door': 14,
228 | 'earth': 13,
229 | 'escalator': 96,
230 | 'fan': 139,
231 | 'fence': 32,
232 | 'field': 29,
233 | 'fireplace': 49,
234 | 'flag': 149,
235 | 'floor': 3,
236 | 'flower': 66,
237 | 'food': 120,
238 | 'fountain': 104,
239 | 'glass': 147,
240 | 'grandstand': 51,
241 | 'grass': 9,
242 | 'hill': 68,
243 | 'hood': 133,
244 | 'house': 25,
245 | 'hovel': 79,
246 | 'kitchen island': 73,
247 | 'lake': 128,
248 | 'lamp': 36,
249 | 'land': 94,
250 | 'light': 82,
251 | 'microwave': 124,
252 | 'minibike': 116,
253 | 'mirror': 27,
254 | 'monitor': 143,
255 | 'mountain': 16,
256 | 'ottoman': 97,
257 | 'oven': 118,
258 | 'painting': 22,
259 | 'palm': 72,
260 | 'path': 52,
261 | 'person': 12,
262 | 'pier': 140,
263 | 'pillow': 57,
264 | 'plant': 17,
265 | 'plate': 142,
266 | 'plaything': 108,
267 | 'pole': 93,
268 | 'pool table': 56,
269 | 'poster': 100,
270 | 'pot': 125,
271 | 'radiator': 146,
272 | 'railing': 38,
273 | 'refrigerator': 50,
274 | 'river': 60,
275 | 'road': 6,
276 | 'rock': 34,
277 | 'rug': 28,
278 | 'runway': 54,
279 | 'sand': 46,
280 | 'sconce': 134,
281 | 'screen': 130,
282 | 'screen door': 58,
283 | 'sculpture': 132,
284 | 'sea': 26,
285 | 'seat': 31,
286 | 'shelf': 24,
287 | 'ship': 103,
288 | 'shower': 145,
289 | 'sidewalk': 11,
290 | 'signboard': 43,
291 | 'sink': 47,
292 | 'sky': 2,
293 | 'skyscraper': 48,
294 | 'sofa': 23,
295 | 'stage': 101,
296 | 'stairs': 53,
297 | 'stairway': 59,
298 | 'step': 121,
299 | 'stool': 110,
300 | 'stove': 71,
301 | 'streetlight': 87,
302 | 'swimming pool': 109,
303 | 'swivel chair': 75,
304 | 'table': 15,
305 | 'tank': 122,
306 | 'television receiver': 89,
307 | 'tent': 114,
308 | 'toilet': 65,
309 | 'towel': 81,
310 | 'tower': 84,
311 | 'trade name': 123,
312 | 'traffic light': 136,
313 | 'tray': 137,
314 | 'tree': 4,
315 | 'truck': 83,
316 | 'van': 102,
317 | 'vase': 135,
318 | 'wall': 0,
319 | 'wardrobe': 35,
320 | 'washer': 107,
321 | 'water': 21,
322 | 'waterfall': 113,
323 | 'windowpane': 8}
324 |
--------------------------------------------------------------------------------
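Usage sketch for the label helpers above:

```python
print(get_class_ids_from_labels(['sky', 'person']))         # [2, 12]
print(get_label_id_mapping(use_human_mask=True)['person'])  # 1
```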
/lib/utils/net_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from torch import nn
4 | import numpy as np
5 | import torch.nn.functional
6 | from collections import OrderedDict
7 | from termcolor import colored
8 | import sys
9 | import yaml
10 | from lib.config import cfg
11 |
12 |
13 | def gen_rays_bbox(rays, bounds):
14 | rays_o, rays_d = rays[..., :3], rays[..., 3:6]
15 | norm_d = torch.norm(rays_d, dim=-1, keepdim=True)
16 | viewdir = rays_d / norm_d
17 | viewdir[(viewdir < 1e-5) & (viewdir > -1e-10)] = 1e-5
18 | viewdir[(viewdir > -1e-5) & (viewdir < 1e-10)] = -1e-5
19 |
20 | tmin = (bounds[0:1] - rays_o[:1]) / viewdir
21 | tmax = (bounds[1:2] - rays_o[:1]) / viewdir
22 | t1 = torch.min(tmin, tmax)
23 | t2 = torch.max(tmin, tmax)
24 |
25 | near = torch.max(t1, dim=-1)[0]
26 | far = torch.min(t2, dim=-1)[0]
27 | mask_at_box = near < far
28 | return mask_at_box
29 |
30 | import time
31 | class perf_timer:
32 | def __init__(self, msg="Elapsed time: {}s", logf=lambda x: print(colored(x, 'yellow')), sync_cuda=True, use_ms=False, disabled=False):
33 | self.logf = logf
34 | self.msg = msg
35 | self.sync_cuda = sync_cuda
36 | self.use_ms = use_ms
37 | self.disabled = disabled
38 |
39 | self.loggedtime = None
40 |
41 | def __enter__(self,):
42 | if self.sync_cuda:
43 | torch.cuda.synchronize()
44 | self.start = time.perf_counter()
45 |
46 | def __exit__(self, exc_type, exc_value, traceback):
47 | if self.sync_cuda:
48 | torch.cuda.synchronize()
49 | self.logtime(self.msg)
50 |
51 | def logtime(self, msg=None, logf=None):
52 | if self.disabled:
53 | return
54 | # SAME CLASS, DIFFERENT FUNCTIONALITY, is this good?
55 | # call the logger for timing code sections
56 | if self.sync_cuda:
57 | torch.cuda.synchronize()
58 |
59 | # always remember current time
60 | prev = self.loggedtime
61 | self.loggedtime = time.perf_counter()
62 |
63 | # print it if we've remembered previous time
64 | if prev is not None and msg:
65 | logf = logf or self.logf
66 | diff = self.loggedtime-prev
67 | diff *= 1000 if self.use_ms else 1
68 | logf(msg.format(diff))
69 |
70 | return self.loggedtime
71 |
72 | def sigmoid(x):
73 | y = torch.clamp(x.sigmoid(), min=1e-4, max=1 - 1e-4)
74 | return y
75 |
76 |
77 | def _neg_loss(pred, gt):
78 | ''' Modified focal loss. Exactly the same as CornerNet.
79 | Runs faster and costs a little bit more memory
80 | Arguments:
81 | pred (batch x c x h x w)
82 | gt_regr (batch x c x h x w)
83 | '''
84 | pos_inds = gt.eq(1).float()
85 | neg_inds = gt.lt(1).float()
86 |
87 | neg_weights = torch.pow(1 - gt, 4)
88 |
89 | loss = 0
90 |
91 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
92 | neg_loss = torch.log(1 - pred) * torch.pow(pred,
93 | 2) * neg_weights * neg_inds
94 |
95 | num_pos = pos_inds.float().sum()
96 | pos_loss = pos_loss.sum()
97 | neg_loss = neg_loss.sum()
98 |
99 | if num_pos == 0:
100 | loss = loss - neg_loss
101 | else:
102 | loss = loss - (pos_loss + neg_loss) / num_pos
103 | return loss
104 |
105 |
106 | class FocalLoss(nn.Module):
107 |     '''nn.Module wrapper for focal loss'''
108 | def __init__(self):
109 | super(FocalLoss, self).__init__()
110 | self.neg_loss = _neg_loss
111 |
112 | def forward(self, out, target):
113 | return self.neg_loss(out, target)
114 |
115 |
116 | def smooth_l1_loss(vertex_pred,
117 | vertex_targets,
118 | vertex_weights,
119 | sigma=1.0,
120 | normalize=True,
121 | reduce=True):
122 | """
123 | :param vertex_pred: [b, vn*2, h, w]
124 | :param vertex_targets: [b, vn*2, h, w]
125 | :param vertex_weights: [b, 1, h, w]
126 | :param sigma:
127 | :param normalize:
128 | :param reduce:
129 | :return:
130 | """
131 | b, ver_dim, _, _ = vertex_pred.shape
132 | sigma_2 = sigma**2
133 | vertex_diff = vertex_pred - vertex_targets
134 | diff = vertex_weights * vertex_diff
135 | abs_diff = torch.abs(diff)
136 | smoothL1_sign = (abs_diff < 1. / sigma_2).detach().float()
137 | in_loss = torch.pow(diff, 2) * (sigma_2 / 2.) * smoothL1_sign \
138 | + (abs_diff - (0.5 / sigma_2)) * (1. - smoothL1_sign)
139 |
140 | if normalize:
141 | in_loss = torch.sum(in_loss.view(b, -1), 1) / (
142 | ver_dim * torch.sum(vertex_weights.view(b, -1), 1) + 1e-3)
143 |
144 | if reduce:
145 | in_loss = torch.mean(in_loss)
146 |
147 | return in_loss
148 |
149 |
150 | class SmoothL1Loss(nn.Module):
151 | def __init__(self):
152 | super(SmoothL1Loss, self).__init__()
153 | self.smooth_l1_loss = smooth_l1_loss
154 |
155 | def forward(self,
156 | preds,
157 | targets,
158 | weights,
159 | sigma=1.0,
160 | normalize=True,
161 | reduce=True):
162 | return self.smooth_l1_loss(preds, targets, weights, sigma, normalize,
163 | reduce)
164 |
165 |
166 | class AELoss(nn.Module):
167 | def __init__(self):
168 | super(AELoss, self).__init__()
169 |
170 | def forward(self, ae, ind, ind_mask):
171 | """
172 | ae: [b, 1, h, w]
173 | ind: [b, max_objs, max_parts]
174 | ind_mask: [b, max_objs, max_parts]
175 | obj_mask: [b, max_objs]
176 | """
177 | # first index
178 | b, _, h, w = ae.shape
179 | b, max_objs, max_parts = ind.shape
180 | obj_mask = torch.sum(ind_mask, dim=2) != 0
181 |
182 | ae = ae.view(b, h * w, 1)
183 | seed_ind = ind.view(b, max_objs * max_parts, 1)
184 | tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)
185 |
186 | # compute the mean
187 | tag_mean = tag * ind_mask
188 | tag_mean = tag_mean.sum(2) / (ind_mask.sum(2) + 1e-4)
189 |
190 | # pull ae of the same object to their mean
191 | pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
192 | obj_num = obj_mask.sum(dim=1).float()
193 | pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum()
194 | pull /= b
195 |
196 | # push away the mean of different objects
197 | push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
198 | push_dist = 1 - push_dist
199 | push_dist = nn.functional.relu(push_dist, inplace=True)
200 | obj_mask = (obj_mask.unsqueeze(1) + obj_mask.unsqueeze(2)) == 2
201 | push_dist = push_dist * obj_mask.float()
202 | push = ((push_dist.sum(dim=(1, 2)) - obj_num) /
203 | (obj_num * (obj_num - 1) + 1e-4)).sum()
204 | push /= b
205 | return pull, push
206 |
207 |
208 | class PolyMatchingLoss(nn.Module):
209 | def __init__(self, pnum):
210 | super(PolyMatchingLoss, self).__init__()
211 |
212 | self.pnum = pnum
213 | batch_size = 1
214 | pidxall = np.zeros(shape=(batch_size, pnum, pnum), dtype=np.int32)
215 | for b in range(batch_size):
216 | for i in range(pnum):
217 | pidx = (np.arange(pnum) + i) % pnum
218 | pidxall[b, i] = pidx
219 |
220 | device = torch.device('cuda')
221 | pidxall = torch.from_numpy(
222 | np.reshape(pidxall, newshape=(batch_size, -1))).to(device)
223 |
224 | self.feature_id = pidxall.unsqueeze_(2).long().expand(
225 | pidxall.size(0), pidxall.size(1), 2).detach()
226 |
227 | def forward(self, pred, gt, loss_type="L2"):
228 | pnum = self.pnum
229 | batch_size = pred.size()[0]
230 | feature_id = self.feature_id.expand(batch_size,
231 | self.feature_id.size(1), 2)
232 | device = torch.device('cuda')
233 |
234 | gt_expand = torch.gather(gt, 1,
235 | feature_id).view(batch_size, pnum, pnum, 2)
236 |
237 | pred_expand = pred.unsqueeze(1)
238 |
239 | dis = pred_expand - gt_expand
240 |
241 | if loss_type == "L2":
242 | dis = (dis**2).sum(3).sqrt().sum(2)
243 | elif loss_type == "L1":
244 | dis = torch.abs(dis).sum(3).sum(2)
245 |
246 | min_dis, min_id = torch.min(dis, dim=1, keepdim=True)
247 | # print(min_id)
248 |
249 | # min_id = torch.from_numpy(min_id.data.cpu().numpy()).to(device)
250 | # min_gt_id_to_gather = min_id.unsqueeze_(2).unsqueeze_(3).long().\
251 | # expand(min_id.size(0), min_id.size(1), gt_expand.size(2), gt_expand.size(3))
252 | # gt_right_order = torch.gather(gt_expand, 1, min_gt_id_to_gather).view(batch_size, pnum, 2)
253 |
254 | return torch.mean(min_dis)
255 |
256 |
257 | class AttentionLoss(nn.Module):
258 | def __init__(self, beta=4, gamma=0.5):
259 | super(AttentionLoss, self).__init__()
260 |
261 | self.beta = beta
262 | self.gamma = gamma
263 |
264 | def forward(self, pred, gt):
265 | num_pos = torch.sum(gt)
266 | num_neg = torch.sum(1 - gt)
267 | alpha = num_neg / (num_pos + num_neg)
268 | edge_beta = torch.pow(self.beta, torch.pow(1 - pred, self.gamma))
269 | bg_beta = torch.pow(self.beta, torch.pow(pred, self.gamma))
270 |
271 | loss = 0
272 | loss = loss - alpha * edge_beta * torch.log(pred) * gt
273 | loss = loss - (1 - alpha) * bg_beta * torch.log(1 - pred) * (1 - gt)
274 | return torch.mean(loss)
275 |
276 |
277 | def _gather_feat(feat, ind, mask=None):
278 | dim = feat.size(2)
279 | ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
280 | feat = feat.gather(1, ind)
281 | if mask is not None:
282 | mask = mask.unsqueeze(2).expand_as(feat)
283 | feat = feat[mask]
284 | feat = feat.view(-1, dim)
285 | return feat
286 |
287 |
288 | def _tranpose_and_gather_feat(feat, ind):
289 | feat = feat.permute(0, 2, 3, 1).contiguous()
290 | feat = feat.view(feat.size(0), -1, feat.size(3))
291 | feat = _gather_feat(feat, ind)
292 | return feat
293 |
294 |
295 | class Ind2dRegL1Loss(nn.Module):
296 | def __init__(self, type='l1'):
297 | super(Ind2dRegL1Loss, self).__init__()
298 | if type == 'l1':
299 | self.loss = torch.nn.functional.l1_loss
300 | elif type == 'smooth_l1':
301 | self.loss = torch.nn.functional.smooth_l1_loss
302 |
303 | def forward(self, output, target, ind, ind_mask):
304 | """ind: [b, max_objs, max_parts]"""
305 | b, max_objs, max_parts = ind.shape
306 | ind = ind.view(b, max_objs * max_parts)
307 | pred = _tranpose_and_gather_feat(output,
308 | ind).view(b, max_objs, max_parts,
309 | output.size(1))
310 | mask = ind_mask.unsqueeze(3).expand_as(pred)
311 | loss = self.loss(pred * mask, target * mask, reduction='sum')
312 | loss = loss / (mask.sum() + 1e-4)
313 | return loss
314 |
315 |
316 | class IndL1Loss1d(nn.Module):
317 | def __init__(self, type='l1'):
318 | super(IndL1Loss1d, self).__init__()
319 | if type == 'l1':
320 | self.loss = torch.nn.functional.l1_loss
321 | elif type == 'smooth_l1':
322 | self.loss = torch.nn.functional.smooth_l1_loss
323 |
324 | def forward(self, output, target, ind, weight):
325 | """ind: [b, n]"""
326 | output = _tranpose_and_gather_feat(output, ind)
327 | weight = weight.unsqueeze(2)
328 | loss = self.loss(output * weight, target * weight, reduction='sum')
329 | loss = loss / (weight.sum() * output.size(2) + 1e-4)
330 | return loss
331 |
332 |
333 | class GeoCrossEntropyLoss(nn.Module):
334 | def __init__(self):
335 | super(GeoCrossEntropyLoss, self).__init__()
336 |
337 | def forward(self, output, target, poly):
338 | output = torch.nn.functional.softmax(output, dim=1)
339 | output = torch.log(torch.clamp(output, min=1e-4))
340 | poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2)
341 | target = target[..., None, None].expand(poly.size(0), poly.size(1), 1,
342 | poly.size(3))
343 | target_poly = torch.gather(poly, 2, target)
344 | sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True)
345 | kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3))
346 | loss = -(output * kernel.transpose(2, 1)).sum(1).mean()
347 | return loss
348 |
349 |
350 | def load_model(net,
351 | optim,
352 | scheduler,
353 | recorder,
354 | model_dir,
355 | resume=True,
356 | epoch=-1):
357 | if not resume:
358 | os.system('rm -rf {}'.format(model_dir))
359 |
360 | if not os.path.exists(model_dir):
361 | return 0
362 |
363 | pths = [
364 | int(pth.split('.')[0]) for pth in os.listdir(model_dir)
365 | if pth != 'latest.pth' and pth != 'psnr_best.pth' and pth != 'ssim_best.pth' and pth != 'lpips_best.pth'
366 | ]
367 | if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir):
368 | return 0
369 | if epoch == -1:
370 | if 'latest.pth' in os.listdir(model_dir):
371 | pth = 'latest'
372 | else:
373 | pth = max(pths)
374 | else:
375 | pth = epoch
376 | print('load model: {}'.format(os.path.join(model_dir,
377 | '{}.pth'.format(pth))))
378 | pretrained_model = torch.load(
379 | os.path.join(model_dir, '{}.pth'.format(pth)), 'cpu')
380 | net.load_state_dict(pretrained_model['net'])
381 | if 'optim' in pretrained_model:
382 | # optim.load_state_dict(pretrained_model['optim']) ###ft
383 | scheduler.load_state_dict(pretrained_model['scheduler'])
384 | recorder.load_state_dict(pretrained_model['recorder'])
385 | return pretrained_model['epoch'] + 1
386 | else:
387 | return 0
388 |
389 |
390 | def save_model(net, optim, scheduler, recorder, model_dir, epoch, custom=None, last=False):
391 | os.system('mkdir -p {}'.format(model_dir))
392 | model = {
393 | 'net': net.state_dict(),
394 | 'optim': optim.state_dict(),
395 | 'scheduler': scheduler.state_dict(),
396 | 'recorder': recorder.state_dict(),
397 | 'epoch': epoch
398 | }
399 | if last:
400 | torch.save(model, os.path.join(model_dir, 'latest.pth'))
401 | else:
402 | torch.save(model, os.path.join(model_dir, '{}.pth'.format(epoch)))
403 | if custom is not None:
404 | torch.save(model, os.path.join(model_dir, f'{custom}.pth'))
405 |
406 | # remove previous pretrained model if the number of models is too big
407 | pths = [
408 | int(pth.split('.')[0]) for pth in os.listdir(model_dir)
409 | if pth != 'latest.pth' and pth != 'psnr_best.pth' and pth != 'ssim_best.pth' and pth != 'lpips_best.pth'
410 | ]
411 | # if len(pths) <= 5:
412 | # return
413 | # os.system('rm {}'.format(
414 | # os.path.join(model_dir, '{}.pth'.format(min(pths)))))
415 |
416 |
417 | def load_network(net, model_dir, resume=True, epoch=-1, custom=None, strict=True):
418 | if not resume:
419 | return 0
420 | if not os.path.exists(model_dir):
421 | print(colored('pretrained model does not exist', 'red'))
422 | return 0
423 |
424 | if os.path.isdir(model_dir):
425 | pths = [
426 | int(pth.split('.')[0]) for pth in os.listdir(model_dir)
427 | if pth != 'latest.pth' and pth != 'psnr_best.pth' and pth != 'ssim_best.pth' and pth != 'lpips_best.pth'
428 | ]
429 | if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir):
430 | return 0
431 | if epoch == -1:
432 | if custom is not None and custom+'.pth' in os.listdir(model_dir):
433 | pth = custom
434 | elif 'latest.pth' in os.listdir(model_dir):
435 | pth = 'latest'
436 | else:
437 | pth = max(pths)
438 | else:
439 | pth = epoch
440 | model_path = os.path.join(model_dir, '{}.pth'.format(pth))
441 | else:
442 | model_path = model_dir
443 | print('load model: {}'.format(model_path))
444 | pretrained_model = torch.load(model_path)
445 | net.load_state_dict(pretrained_model['net'], strict=strict)
446 | if 'epoch' in pretrained_model:
447 | return pretrained_model['epoch'] + 1
448 | else:
449 | return 0
450 |
451 |
452 | def remove_net_prefix(net, prefix):
453 | net_ = OrderedDict()
454 | for k in net.keys():
455 | if k.startswith(prefix):
456 | net_[k[len(prefix):]] = net[k]
457 | else:
458 | net_[k] = net[k]
459 | return net_
460 |
461 |
462 | def add_net_prefix(net, prefix):
463 | net_ = OrderedDict()
464 | for k in net.keys():
465 | net_[prefix + k] = net[k]
466 | return net_
467 |
468 |
469 | def replace_net_prefix(net, orig_prefix, prefix):
470 | net_ = OrderedDict()
471 | for k in net.keys():
472 | if k.startswith(orig_prefix):
473 | net_[prefix + k[len(orig_prefix):]] = net[k]
474 | else:
475 | net_[k] = net[k]
476 | return net_
477 |
478 |
479 | def remove_net_layer(net, layers):
480 | keys = list(net.keys())
481 | for k in keys:
482 | for layer in layers:
483 | if k.startswith(layer):
484 | del net[k]
485 | return net
486 |
487 | def save_trained_config(cfg):
488 | if not cfg.resume:
489 | os.system('rm -rf ' + cfg.trained_config_dir+'/*')
490 | os.system('mkdir -p ' + cfg.trained_config_dir)
491 | train_cmd = ' '.join(sys.argv)
492 | train_cmd_path = os.path.join(cfg.trained_config_dir, 'train_cmd.txt')
493 | train_config_path = os.path.join(cfg.trained_config_dir, 'train_config.yaml')
494 | open(train_cmd_path, 'w').write(train_cmd)
495 | yaml.dump(cfg, open(train_config_path, 'w'))
496 |
497 | def load_pretrain(net, model_dir):
498 | model_dir = os.path.join(cfg.workspace, 'trained_model', cfg.task, model_dir)
499 | if not os.path.exists(model_dir):
500 | return 1
501 |
502 | pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir) if pth != 'latest.pth']
503 | if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir):
504 | return 1
505 |
506 | if 'latest.pth' in os.listdir(model_dir):
507 | pth = 'latest'
508 | else:
509 | pth = max(pths)
510 | print('Load pretrain model: {}'.format(os.path.join(model_dir, '{}.pth'.format(pth))))
511 | pretrained_model = torch.load(os.path.join(model_dir, '{}.pth'.format(pth)), 'cpu')
512 | net.load_state_dict(pretrained_model['net'])
513 | return 0
514 |
515 | def save_pretrain(net, task, model_dir):
516 | model_dir = os.path.join('data/trained_model', task, model_dir)
517 | os.system('mkdir -p ' + model_dir)
518 | model = {'net': net.state_dict()}
519 | torch.save(model, os.path.join(model_dir, 'latest.pth'))
520 |
521 |
522 |
--------------------------------------------------------------------------------
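A minimal round-trip sketch for `save_model`/`load_model` with a stub recorder (the path and toy module are made up):

```python
import torch
import torch.nn as nn

class StubRecorder:  # stands in for lib.train.recorder.Recorder
    def state_dict(self): return {'step': 0}
    def load_state_dict(self, d): pass

net = nn.Linear(4, 4)
optim = torch.optim.Adam(net.parameters())
sched = torch.optim.lr_scheduler.StepLR(optim, step_size=10)
rec = StubRecorder()
save_model(net, optim, sched, rec, '/tmp/gefu_ckpt', epoch=0, last=True)  # writes latest.pth
print(load_model(net, optim, sched, rec, '/tmp/gefu_ckpt'))               # resumes, returns 1
```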
/lib/utils/optimizer/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from bisect import bisect_right
2 | from collections import Counter
3 |
4 | import torch
5 |
6 |
7 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
8 | def __init__(
9 | self,
10 | optimizer,
11 | milestones,
12 | gamma=0.1,
13 | warmup_factor=1.0 / 3,
14 | warmup_iters=5,
15 | warmup_method="linear",
16 | last_epoch=-1,
17 | ):
18 | if not list(milestones) == sorted(milestones):
19 |             raise ValueError(
20 |                 "Milestones should be a list of"
21 |                 " increasing integers. Got {}".format(milestones)
22 |             )
23 |
24 | if warmup_method not in ("constant", "linear"):
25 | raise ValueError(
26 | "Only 'constant' or 'linear' warmup_method accepted"
27 | "got {}".format(warmup_method)
28 | )
29 | self.milestones = milestones
30 | self.gamma = gamma
31 | self.warmup_factor = warmup_factor
32 | self.warmup_iters = warmup_iters
33 | self.warmup_method = warmup_method
34 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
35 |
36 | def get_lr(self):
37 | warmup_factor = 1
38 | if self.last_epoch < self.warmup_iters:
39 | if self.warmup_method == "constant":
40 | warmup_factor = self.warmup_factor
41 | elif self.warmup_method == "linear":
42 | alpha = float(self.last_epoch) / self.warmup_iters
43 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
44 | return [
45 | base_lr
46 | * warmup_factor
47 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
48 | for base_lr in self.base_lrs
49 | ]
50 |
51 |
52 | class MultiStepLR(torch.optim.lr_scheduler._LRScheduler):
53 |
54 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
55 | self.milestones = Counter(milestones)
56 | self.gamma = gamma
57 | super(MultiStepLR, self).__init__(optimizer, last_epoch)
58 |
59 | def get_lr(self):
60 | if self.last_epoch not in self.milestones:
61 | return [group['lr'] for group in self.optimizer.param_groups]
62 | return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
63 | for group in self.optimizer.param_groups]
64 |
65 |
66 | class ExponentialLR(torch.optim.lr_scheduler._LRScheduler):
67 |
68 | def __init__(self, optimizer, decay_epochs, gamma=0.1, last_epoch=-1):
69 | self.decay_epochs = decay_epochs
70 | self.gamma = gamma
71 | super(ExponentialLR, self).__init__(optimizer, last_epoch)
72 |
73 | def get_lr(self):
74 | return [base_lr * self.gamma ** (self.last_epoch / self.decay_epochs)
75 | for base_lr in self.base_lrs]
76 |
--------------------------------------------------------------------------------
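During warmup, `WarmupMultiStepLR` linearly ramps the multiplier from `warmup_factor` to 1 before the usual multi-step decay takes over; a quick sketch of the ramp:

```python
# The linear warmup factor computed in WarmupMultiStepLR.get_lr():
warmup_factor, warmup_iters = 1.0 / 3, 5
for last_epoch in range(7):
    if last_epoch < warmup_iters:
        alpha = last_epoch / warmup_iters
        factor = warmup_factor * (1 - alpha) + alpha
    else:
        factor = 1.0
    print(last_epoch, round(factor, 3))  # ramps 0.333 -> 1.0 over warmup_iters steps
```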
/lib/utils/optimizer/radam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer, required
4 |
5 |
6 | class RAdam(Optimizer):
7 |
8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
9 | if not 0.0 <= lr:
10 | raise ValueError("Invalid learning rate: {}".format(lr))
11 | if not 0.0 <= eps:
12 | raise ValueError("Invalid epsilon value: {}".format(eps))
13 | if not 0.0 <= betas[0] < 1.0:
14 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
15 | if not 0.0 <= betas[1] < 1.0:
16 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
17 |
18 | self.degenerated_to_sgd = degenerated_to_sgd
19 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
20 | for param in params:
21 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
22 | param['buffer'] = [[None, None, None] for _ in range(10)]
23 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
24 | super(RAdam, self).__init__(params, defaults)
25 |
26 | def __setstate__(self, state):
27 | super(RAdam, self).__setstate__(state)
28 |
29 | def step(self, closure=None):
30 |
31 | loss = None
32 | if closure is not None:
33 | loss = closure()
34 |
35 | for group in self.param_groups:
36 |
37 | for p in group['params']:
38 | if p.grad is None:
39 | continue
40 | grad = p.grad.data.float()
41 | if grad.is_sparse:
42 | raise RuntimeError('RAdam does not support sparse gradients')
43 |
44 | p_data_fp32 = p.data.float()
45 |
46 | state = self.state[p]
47 |
48 | if len(state) == 0:
49 | state['step'] = 0
50 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
51 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
52 | else:
53 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
54 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
55 |
56 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
57 | beta1, beta2 = group['betas']
58 |
59 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
60 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
61 |
62 | state['step'] += 1
63 | buffered = group['buffer'][int(state['step'] % 10)]
64 | if state['step'] == buffered[0]:
65 | N_sma, step_size = buffered[1], buffered[2]
66 | else:
67 | buffered[0] = state['step']
68 | beta2_t = beta2 ** state['step']
69 | N_sma_max = 2 / (1 - beta2) - 1
70 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
71 | buffered[1] = N_sma
72 |
73 | # more conservative since it's an approximated value
74 | if N_sma >= 5:
75 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
76 | elif self.degenerated_to_sgd:
77 | step_size = 1.0 / (1 - beta1 ** state['step'])
78 | else:
79 | step_size = -1
80 | buffered[2] = step_size
81 |
82 | # more conservative since it's an approximated value
83 | if N_sma >= 5:
84 | if group['weight_decay'] != 0:
85 |                         p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
86 | denom = exp_avg_sq.sqrt().add_(group['eps'])
87 |                     p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
88 | p.data.copy_(p_data_fp32)
89 | elif step_size > 0:
90 | if group['weight_decay'] != 0:
91 |                         p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
92 |                     p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
93 | p.data.copy_(p_data_fp32)
94 |
95 | return loss
96 |
97 | class PlainRAdam(Optimizer):
98 |
99 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
100 | if not 0.0 <= lr:
101 | raise ValueError("Invalid learning rate: {}".format(lr))
102 | if not 0.0 <= eps:
103 | raise ValueError("Invalid epsilon value: {}".format(eps))
104 | if not 0.0 <= betas[0] < 1.0:
105 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
106 | if not 0.0 <= betas[1] < 1.0:
107 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
108 |
109 | self.degenerated_to_sgd = degenerated_to_sgd
110 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
111 |
112 | super(PlainRAdam, self).__init__(params, defaults)
113 |
114 | def __setstate__(self, state):
115 | super(PlainRAdam, self).__setstate__(state)
116 |
117 | def step(self, closure=None):
118 |
119 | loss = None
120 | if closure is not None:
121 | loss = closure()
122 |
123 | for group in self.param_groups:
124 |
125 | for p in group['params']:
126 | if p.grad is None:
127 | continue
128 | grad = p.grad.data.float()
129 | if grad.is_sparse:
130 | raise RuntimeError('RAdam does not support sparse gradients')
131 |
132 | p_data_fp32 = p.data.float()
133 |
134 | state = self.state[p]
135 |
136 | if len(state) == 0:
137 | state['step'] = 0
138 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
139 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
140 | else:
141 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
142 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
143 |
144 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
145 | beta1, beta2 = group['betas']
146 |
147 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
148 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
149 |
150 | state['step'] += 1
151 | beta2_t = beta2 ** state['step']
152 | N_sma_max = 2 / (1 - beta2) - 1
153 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
154 |
155 |
156 | # more conservative since it's an approximated value
157 | if N_sma >= 5:
158 | if group['weight_decay'] != 0:
159 |                         p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
160 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
161 | denom = exp_avg_sq.sqrt().add_(group['eps'])
162 |                     p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
163 | p.data.copy_(p_data_fp32)
164 | elif self.degenerated_to_sgd:
165 | if group['weight_decay'] != 0:
166 |                         p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
167 | step_size = group['lr'] / (1 - beta1 ** state['step'])
168 |                     p_data_fp32.add_(exp_avg, alpha=-step_size)
169 | p.data.copy_(p_data_fp32)
170 |
171 | return loss
172 |
173 |
174 | class AdamW(Optimizer):
175 |
176 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):
177 | if not 0.0 <= lr:
178 | raise ValueError("Invalid learning rate: {}".format(lr))
179 | if not 0.0 <= eps:
180 | raise ValueError("Invalid epsilon value: {}".format(eps))
181 | if not 0.0 <= betas[0] < 1.0:
182 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
183 | if not 0.0 <= betas[1] < 1.0:
184 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
185 |
186 | defaults = dict(lr=lr, betas=betas, eps=eps,
187 | weight_decay=weight_decay, warmup = warmup)
188 | super(AdamW, self).__init__(params, defaults)
189 |
190 | def __setstate__(self, state):
191 | super(AdamW, self).__setstate__(state)
192 |
193 | def step(self, closure=None):
194 | loss = None
195 | if closure is not None:
196 | loss = closure()
197 |
198 | for group in self.param_groups:
199 |
200 | for p in group['params']:
201 | if p.grad is None:
202 | continue
203 | grad = p.grad.data.float()
204 | if grad.is_sparse:
205 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
206 |
207 | p_data_fp32 = p.data.float()
208 |
209 | state = self.state[p]
210 |
211 | if len(state) == 0:
212 | state['step'] = 0
213 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
214 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
215 | else:
216 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
217 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
218 |
219 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
220 | beta1, beta2 = group['betas']
221 |
222 | state['step'] += 1
223 |
224 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
225 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
226 |
227 | denom = exp_avg_sq.sqrt().add_(group['eps'])
228 | bias_correction1 = 1 - beta1 ** state['step']
229 | bias_correction2 = 1 - beta2 ** state['step']
230 |
231 | if group['warmup'] > state['step']:
232 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
233 | else:
234 | scheduled_lr = group['lr']
235 |
236 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
237 |
238 | if group['weight_decay'] != 0:
239 |                     p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * scheduled_lr)
240 |
241 |                 p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
242 |
243 | p.data.copy_(p_data_fp32)
244 |
245 | return loss
246 |
247 |
--------------------------------------------------------------------------------
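For reference, the buffered step-size branch in `RAdam.step` above (file lines 63-80) evaluates the variance-rectification rule of Liu et al., "On the Variance of the Adaptive Learning Rate and Beyond". Writing `N_sma_max` as $\rho_\infty$ and `N_sma` as $\rho_t$, a sketch of what the code computes:

    \rho_\infty = \frac{2}{1 - \beta_2} - 1,
    \qquad
    \rho_t = \rho_\infty - \frac{2 t \beta_2^{t}}{1 - \beta_2^{t}},

    \mathrm{step\_size}_t =
    \begin{cases}
    \dfrac{1}{1 - \beta_1^{t}}
    \sqrt{(1 - \beta_2^{t})\,
    \dfrac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}
    & \text{if } \rho_t \ge 5, \\[2ex]
    \dfrac{1}{1 - \beta_1^{t}} & \text{otherwise (SGD fallback, when degenerated\_to\_sgd)},
    \end{cases}

with the rectified update $\theta_t = \theta_{t-1} - \mathrm{lr} \cdot \mathrm{step\_size}_t \cdot m_t / (\sqrt{v_t} + \epsilon)$ when $\rho_t \ge 5$, and a plain momentum step (no denominator) in the fallback branch.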
/lib/utils/ply_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | from PIL import Image
4 | from plyfile import PlyData, PlyElement
5 |
6 | def read_pfm(filename):
7 | file = open(filename, 'rb')
8 | color = None
9 | width = None
10 | height = None
11 | scale = None
12 | endian = None
13 |
14 | header = file.readline().decode('utf-8').rstrip()
15 | if header == 'PF':
16 | color = True
17 | elif header == 'Pf':
18 | color = False
19 | else:
20 | raise Exception('Not a PFM file.')
21 |
22 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
23 | if dim_match:
24 | width, height = map(int, dim_match.groups())
25 | else:
26 | raise Exception('Malformed PFM header.')
27 |
28 | scale = float(file.readline().rstrip())
29 | if scale < 0: # little-endian
30 | endian = '<'
31 | scale = -scale
32 | else:
33 | endian = '>' # big-endian
34 |
35 | data = np.fromfile(file, endian + 'f')
36 | shape = (height, width, 3) if color else (height, width)
37 |
38 | data = np.reshape(data, shape)
39 | data = np.flipud(data)
40 | file.close()
41 | return data, scale
42 |
43 | def read_camera_parameters(filename):
44 | with open(filename) as f:
45 | lines = f.readlines()
46 | lines = [line.rstrip() for line in lines]
47 | # extrinsics: line [1,5), 4x4 matrix
48 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
49 |     # intrinsics: lines [7,10), 3x3 matrix
50 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
51 | return intrinsics, extrinsics
52 |
53 | def read_img(filename):
54 | img = Image.open(filename)
55 | # scale 0~255 to 0~1
56 | np_img = np.array(img, dtype=np.float32) / 255.
57 | return np_img
58 |
59 |
60 | def storePly(path, xyz, rgb):
61 | # Define the dtype for the structured array
62 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
63 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
64 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
65 |
66 | normals = np.zeros_like(xyz)
67 |
68 | elements = np.empty(xyz.shape[0], dtype=dtype)
69 | attributes = np.concatenate((xyz, normals, rgb), axis=1)
70 | elements[:] = list(map(tuple, attributes))
71 |
72 | # Create the PlyData object and write to file
73 | vertex_element = PlyElement.describe(elements, 'vertex')
74 | ply_data = PlyData([vertex_element])
75 | ply_data.write(path)
--------------------------------------------------------------------------------
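A small usage sketch for `storePly` above (the array contents are hypothetical): it expects `xyz` and `rgb` as `[N, 3]` arrays, with colors already in the 0-255 range, and writes zero normals alongside each vertex.

    import numpy as np
    from lib.utils.ply_utils import storePly

    xyz = np.random.rand(100, 3).astype(np.float32)        # points in world space
    rgb = (np.random.rand(100, 3) * 255).astype(np.uint8)  # per-point colors, uint8
    storePly('/tmp/points.ply', xyz, rgb)                  # x/y/z, zero nx/ny/nz, red/green/blue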
/lib/utils/rend_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | def normalize(x):
5 | return x / np.linalg.norm(x)
6 |
7 | def ptstocam(pts, c2w):
8 | tt = np.matmul(c2w[:3, :3].T, (pts-c2w[:3, 3])[..., np.newaxis])[..., 0]
9 | return tt
10 |
11 | def viewmatrix(z, up, pos):
12 | vec2 = normalize(z)
13 | vec0_avg = up
14 | vec1 = normalize(np.cross(vec2, vec0_avg))
15 | vec0 = normalize(np.cross(vec1, vec2))
16 | m = np.stack([vec0, vec1, vec2, pos], 1)
17 | return m
18 |
19 | def gen_path(RT, center=None, num_views=100):  # note: shadowed by the second gen_path defined below
20 | lower_row = np.array([[0., 0., 0., 1.]])
21 |
22 | # transfer RT to camera_to_world matrix
23 | RT = np.array(RT)
24 | RT[:] = np.linalg.inv(RT[:])
25 |
26 | RT = np.concatenate([RT[:, :, 1:2], RT[:, :, 0:1],
27 | -RT[:, :, 2:3], RT[:, :, 3:4]], 2)
28 |
29 | up = normalize(RT[:, :3, 0].sum(0)) # average up vector
30 | z = normalize(RT[0, :3, 2])
31 | vec1 = normalize(np.cross(z, up))
32 | vec2 = normalize(np.cross(up, vec1))
33 | z_off = 0
34 |
35 | if center is None:
36 | center = RT[:, :3, 3].mean(0)
37 | z_off = 1.3
38 |
39 | c2w = np.stack([up, vec1, vec2, center], 1)
40 |
41 | # get radii for spiral path
42 | tt = ptstocam(RT[:, :3, 3], c2w).T
43 | rads = np.percentile(np.abs(tt), 80, -1)
44 | rads = rads * 1.3
45 | rads = np.array(list(rads) + [1.])
46 |
47 | render_w2c = []
48 | for theta in np.linspace(0., 2 * np.pi, num_views + 1)[:-1]:
49 | # camera position
50 | cam_pos = np.array([0, np.sin(theta), np.cos(theta), 1] * rads)
51 | cam_pos_world = np.dot(c2w[:3, :4], cam_pos)
52 | # z axis
53 | z = normalize(cam_pos_world -
54 | np.dot(c2w[:3, :4], np.array([z_off, 0, 0, 1.])))
55 | # vector -> 3x4 matrix (camera_to_world)
56 | mat = viewmatrix(z, up, cam_pos_world)
57 |
58 | mat = np.concatenate([mat[:, 1:2], mat[:, 0:1],
59 | -mat[:, 2:3], mat[:, 3:4]], 1)
60 | mat = np.concatenate([mat, lower_row], 0)
61 | mat = np.linalg.inv(mat)
62 | render_w2c.append(mat)
63 |
64 | return render_w2c
65 |
66 | def create_center_radius(center, radius=5., up='y', ranges=[0, 360, 36], angle_x=0, **kwargs):
67 | center = np.array(center).reshape(1, 3)
68 | thetas = np.deg2rad(np.linspace(*ranges))
69 | st = np.sin(thetas)
70 | ct = np.cos(thetas)
71 | zero = np.zeros_like(st)
72 | Rotx = cv2.Rodrigues(np.deg2rad(angle_x) * np.array([1., 0., 0.]))[0]
73 | if up == 'z':
74 | center = np.stack([radius*ct, radius*st, zero], axis=1) + center
75 | R = np.stack([-st, ct, zero, zero, zero, zero-1, -ct, -st, zero], axis=-1)
76 | elif up == 'y':
77 | center = np.stack([radius*ct, zero, radius*st, ], axis=1) + center
78 | R = np.stack([
79 | +st, zero, -ct,
80 | zero, zero-1, zero,
81 | -ct, zero, -st], axis=-1)
82 | R = R.reshape(-1, 3, 3)
83 | R = np.einsum('ab,fbc->fac', Rotx, R)
84 | center = center.reshape(-1, 3, 1)
85 | T = - R @ center
86 | RT = np.dstack([R, T])
87 | return RT
88 |
89 |
90 | def gen_path(RT, center=None, radius=3.2, up='z', ranges=[0, 360, 90], angle_x=27, **kwargs):  # overrides the spiral-path gen_path above
91 |     ranges = [0, 360, kwargs.get('num_views', ranges[2])]
92 | c2ws = np.linalg.inv(RT[:])
93 | if center is None:
94 | center = c2ws[:, :3, 3].mean(0)
95 | center[0], center[1] = 0., 0.
96 |
97 | RTs = []
98 | center = np.array(center).reshape(1, 3)
99 | thetas = np.deg2rad(np.linspace(*ranges))
100 | st = np.sin(thetas)
101 | ct = np.cos(thetas)
102 | zero = np.zeros_like(st)
103 | Rotx = cv2.Rodrigues(np.deg2rad(angle_x) * np.array([1., 0., 0.]))[0]
104 | if up == 'z':
105 | center = np.stack([radius*ct, radius*st, zero], axis=1) + center
106 | R = np.stack([-st, ct, zero, zero, zero, zero-1, -ct, -st, zero], axis=-1)
107 | elif up == 'y':
108 | center = np.stack([radius*ct, zero, radius*st, ], axis=1) + center
109 | R = np.stack([
110 | +st, zero, -ct,
111 | zero, zero-1, zero,
112 | -ct, zero, -st], axis=-1)
113 | R = R.reshape(-1, 3, 3)
114 | R = np.einsum('ab,fbc->fac', Rotx, R)
115 | center = center.reshape(-1, 3, 1)
116 | T = - R @ center
117 | RT = np.dstack([R, T])
118 | RT_bottom = np.zeros_like(RT[:, :1])
119 | RT_bottom[:, :, 3] = 1
120 | # __import__('ipdb').set_trace()
121 | ext = np.concatenate([RT, RT_bottom], axis=1)
122 | c2w = np.linalg.inv(ext)
123 | # __import__('ipdb').set_trace()
124 | # import matplotlib.pyplot as plt
125 | # plt.plot(c2ws[:, 0, 3], c2ws[:, 1, 3], '.')
126 | # plt.plot(c2w[:, 0, 3], c2w[:, 1, 3], '.')
127 | # plt.show()
128 | return ext
129 |
130 | def gen_nerf_path(c2ws, depth_ranges, rads_scale=.5, N_views=60):
131 | c2w = poses_avg(c2ws)
132 | up = normalize(c2ws[:, :3, 1].sum(0))
133 |
134 | close_depth, inf_depth = depth_ranges
135 | dt = .75
136 | mean_dz = 1./(( (1.-dt)/close_depth + dt/inf_depth ))
137 | focal = mean_dz
138 |
139 | shrink_factor = .8
140 | zdelta = close_depth * .2
141 |     tt = c2ws[:, :3, 3] - c2w[:3, 3][None]  # camera positions centered on the average pose
142 | rads = np.percentile(np.abs(tt), 70, 0)*rads_scale
143 |
144 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
145 | return render_poses
146 |
147 | def poses_avg(poses):
148 | center = poses[:, :3, 3].mean(0)
149 | vec2 = normalize(poses[:, :3, 2].sum(0))
150 | up = poses[:, :3, 1].sum(0)
151 | c2w = viewmatrix(vec2, up, center)
152 | return c2w
153 |
154 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
155 | render_poses = []
156 | rads = np.array(list(rads) + [1.])
157 |
158 | for theta in np.linspace(0., 2. * np.pi * N_rots, N+1)[:-1]:
159 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
160 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
161 | render_poses.append(viewmatrix(z, up, c))
162 | return render_poses
163 |
--------------------------------------------------------------------------------
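A usage sketch for `gen_nerf_path` above (the camera poses and depth range are hypothetical placeholders): given `[N, 3, 4]` camera-to-world matrices and a `(near, far)` depth range, it returns `N_views` camera-to-world matrices of shape `[3, 4]` along a spiral around the average pose.

    import numpy as np
    from lib.utils.rend_utils import gen_nerf_path

    c2ws = np.stack([np.eye(4)[:3] for _ in range(4)])  # [4, 3, 4] toy cameras
    c2ws[:, 0, 3] = np.linspace(-0.1, 0.1, 4)           # spread them along x
    path = gen_nerf_path(c2ws, depth_ranges=(2.0, 6.0), rads_scale=.5, N_views=60)
    # path: list of 60 [3, 4] camera-to-world matrices along the spiral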
/lib/utils/vis_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | def get_bound_corners(bounds):
5 | min_x, min_y, min_z = bounds[0]
6 | max_x, max_y, max_z = bounds[1]
7 | corners_3d = np.array([
8 | [min_x, min_y, min_z],
9 | [min_x, min_y, max_z],
10 | [min_x, max_y, min_z],
11 | [min_x, max_y, max_z],
12 | [max_x, min_y, min_z],
13 | [max_x, min_y, max_z],
14 | [max_x, max_y, min_z],
15 | [max_x, max_y, max_z],
16 | ])
17 | return corners_3d
18 |
19 | def project(xyz, K, RT):
20 | """
21 | xyz: [N, 3]
22 | K: [3, 3]
23 | RT: [3, 4]
24 | """
25 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T
26 |
27 | xyz = np.dot(xyz, K.T)
28 | xy = xyz[:, :2] / xyz[:, 2:]
29 | return xy
30 |
31 | row_col_ = {
32 | 2: (2, 1),
33 | 7: (2, 4),
34 | 8: (2, 4),
35 | 9: (3, 3),
36 | 26: (4, 7)
37 | }
38 |
39 | row_col_square = {
40 | 2: (2, 1),
41 | 7: (3, 3),
42 | 8: (3, 3),
43 | 9: (3, 3),
44 | 26: (5, 5)
45 | }
46 |
47 | def get_row_col(l, square):
48 | if square and l in row_col_square.keys():
49 | return row_col_square[l]
50 | if l in row_col_.keys():
51 | return row_col_[l]
52 | else:
53 | from math import sqrt
54 | row = int(sqrt(l) + 0.5)
55 |         col = int(l / row + 0.5)
56 |         if row*col < l:
57 |             col = col + 1
58 |         if row > col:
59 |             row, col = col, row
60 | return row, col
61 |
62 | def merge(images, row=-1, col=-1, resize=False, ret_range=False, square=False, **kwargs):
63 | if row == -1 and col == -1:
64 | row, col = get_row_col(len(images), square)
65 | height = images[0].shape[0]
66 | width = images[0].shape[1]
67 | # special case
68 | if height > width:
69 | if len(images) == 3:
70 | row, col = 1, 3
71 | if len(images[0].shape) > 2:
72 | ret_img = np.zeros((height * row, width * col, images[0].shape[2]), dtype=np.uint8) + 255
73 | else:
74 | ret_img = np.zeros((height * row, width * col), dtype=np.uint8) + 255
75 | ranges = []
76 | for i in range(row):
77 | for j in range(col):
78 | if i*col + j >= len(images):
79 | break
80 | img = images[i * col + j]
81 | # resize the image size
82 | img = cv2.resize(img, (width, height))
83 | ret_img[height * i: height * (i+1), width * j: width * (j+1)] = img
84 | ranges.append((width*j, height*i, width*(j+1), height*(i+1)))
85 | if resize:
86 | min_height = 1000
87 | if ret_img.shape[0] > min_height:
88 | scale = min_height/ret_img.shape[0]
89 | ret_img = cv2.resize(ret_img, None, fx=scale, fy=scale)
90 | if ret_range:
91 | return ret_img, ranges
92 | return ret_img
93 |
--------------------------------------------------------------------------------
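A usage sketch for `merge` above (hypothetical images): it tiles equally sized images into a grid chosen by `get_row_col`, filling unused cells with white, and with `ret_range=True` also returns each tile's `(x0, y0, x1, y1)` placement.

    import numpy as np
    from lib.utils.vis_utils import merge

    imgs = [np.full((120, 160, 3), c, dtype=np.uint8) for c in (0, 85, 170, 255)]
    grid = merge(imgs)                          # 2x2 grid -> shape (240, 320, 3)
    grid, ranges = merge(imgs, ret_range=True)  # ranges: four placement rectangles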
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyyaml
2 | opencv-python
3 | imgaug
4 | plyfile
5 | tqdm
6 | kornia
7 | ipdb
8 | lpips
9 | tensorboardX
10 | glfw
11 | pyglm
12 | pyopengl
13 | imgui
14 | termcolor
15 | trimesh
16 | scikit-image==0.19.0
17 | imageio==2.27.0
18 | imageio[ffmpeg]
19 | imageio[pyav]
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | from lib.config import cfg, args
2 | from lib.utils.ply_utils import *
3 | import numpy as np
4 | import os
5 | import glob
6 |
7 | def run_dataset():
8 | from lib.datasets import make_data_loader
9 | import tqdm
10 |
11 | cfg.train.num_workers = 0
12 | data_loader = make_data_loader(cfg, is_train=False)
13 | for batch in tqdm.tqdm(data_loader):
14 | pass
15 |
16 | def run_network():
17 | from lib.networks import make_network
18 | from lib.datasets import make_data_loader
19 | from lib.utils.net_utils import load_network
20 | from lib.utils.data_utils import to_cuda
21 | import tqdm
22 | import torch
23 | import time
24 |
25 | network = make_network(cfg).cuda()
26 | load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch)
27 | network.eval()
28 |
29 | data_loader = make_data_loader(cfg, is_train=False)
30 | total_time = 0
31 | for batch in tqdm.tqdm(data_loader):
32 | batch = to_cuda(batch)
33 | with torch.no_grad():
34 | torch.cuda.synchronize()
35 | start = time.time()
36 | network(batch)
37 | torch.cuda.synchronize()
38 | total_time += time.time() - start
39 | print(total_time / len(data_loader))
40 |
41 | def run_evaluate():
42 | from lib.datasets import make_data_loader
43 | from lib.evaluators import make_evaluator
44 | import tqdm
45 | import torch
46 | from lib.networks import make_network
47 | from lib.utils import net_utils
48 | import time
49 |
50 | network = make_network(cfg).cuda()
51 | net_utils.load_network(network,
52 | cfg.trained_model_dir,
53 | resume=cfg.resume,
54 | epoch=cfg.test.epoch)
55 | network.eval()
56 |
57 | data_loader = make_data_loader(cfg, is_train=False)
58 | evaluator = make_evaluator(cfg)
59 | net_time = []
60 | for batch in tqdm.tqdm(data_loader):
61 | for k in batch:
62 | if k != 'meta':
63 | if k == 'rendering_video_meta':
64 | for i in range(len(batch[k])):
65 | for v in batch[k][i]:
66 | batch[k][i][v] = batch[k][i][v].cuda()
67 | else:
68 | batch[k] = batch[k].cuda()
69 | if cfg.save_video:
70 | with torch.no_grad():
71 | network(batch)
72 | else:
73 | with torch.no_grad():
74 | torch.cuda.synchronize()
75 | start_time = time.time()
76 | output = network(batch)
77 | torch.cuda.synchronize()
78 | end_time = time.time()
79 | net_time.append(end_time - start_time)
80 | evaluator.evaluate(output, batch)
81 |
82 | if not cfg.save_video:
83 | evaluator.summarize()
84 | if len(net_time) > 1:
85 | # print('net_time: ', np.mean(net_time[1:]))
86 | print('FPS: ', 1./np.mean(net_time[1:]))
87 | else:
88 | # print('net_time: ', np.mean(net_time))
89 | print('FPS: ', 1./np.mean(net_time))
90 |
91 |
92 | if cfg.save_ply:
93 | dataset_name = cfg.train_dataset_module.split('.')[-2]
94 | ply_dir = os.path.join(cfg.result_dir, 'pointclouds', dataset_name)
95 | for item in os.listdir(ply_dir):
96 | data_dir = os.path.join(ply_dir, item)
97 | img_dir = os.path.join(data_dir, 'images')
98 | depth_dir = os.path.join(data_dir, 'depth')
99 | cam_dir = os.path.join(data_dir, 'cam')
100 | img_ls = glob.glob(os.path.join(img_dir, '*.png'))
101 | img_name = [os.path.basename(im).split('.')[0] for im in img_ls]
102 |
103 | # for the final point cloud
104 | vertexs = []
105 | vertex_colors = []
106 |
107 | for name in img_name:
108 | ref_name = name
109 |
110 | ref_intrinsics, ref_extrinsics = read_camera_parameters(os.path.join(cam_dir, ref_name+'.txt'))
111 | ref_img = read_img(os.path.join(img_dir, ref_name+'.png'))
112 | ref_depth_est = read_pfm(os.path.join(depth_dir, ref_name+'.pfm'))[0]
113 |
114 | height, width = ref_depth_est.shape[:2]
115 | x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))
116 | x, y = x.reshape(-1), y.reshape(-1)
117 | depth = ref_depth_est.reshape(-1)
118 | color = ref_img.reshape(-1,3)
119 | xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics),
120 | np.vstack((x, y, np.ones_like(x))) * depth)
121 | xyz_world = np.matmul(np.linalg.inv(ref_extrinsics),
122 | np.vstack((xyz_ref, np.ones_like(x))))[:3]
123 | vertexs.append(xyz_world.transpose((1, 0)))
124 | vertex_colors.append((color * 255).astype(np.uint8))
125 | vertexs = np.concatenate(vertexs, axis=0)
126 | vertex_colors = np.concatenate(vertex_colors, axis=0)
127 | scene = os.path.basename(data_dir)
128 | ply_path = os.path.join(data_dir, f'{scene}.ply')
129 | print(f'saving {ply_path}')
130 | storePly(ply_path, vertexs, vertex_colors)
131 |
132 | ## point cloud --> mesh
133 | import open3d as o3d
134 | import trimesh
135 | pcd = o3d.io.read_point_cloud(ply_path)
136 | pcd.estimate_normals()
137 |
138 | # estimate radius for rolling ball
139 | distances = pcd.compute_nearest_neighbor_distance()
140 | avg_dist = np.mean(distances)
141 | radius = 1.5 * avg_dist
142 |
143 | mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
144 | pcd,
145 | o3d.utility.DoubleVector([radius, radius * 2]))
146 |
147 | # create the triangular mesh with the vertices and faces from open3d
148 | tri_mesh = trimesh.Trimesh(np.asarray(mesh.vertices), np.asarray(mesh.triangles),
149 | vertex_normals=np.asarray(mesh.vertex_normals))
150 |
151 | trimesh.convex.is_convex(tri_mesh)
152 | ply_path = os.path.join(data_dir, f'{scene}_mesh.ply')
153 | trimesh.exchange.export.export_mesh(tri_mesh, ply_path)
154 |
155 |
156 |
157 | if __name__ == '__main__':
158 | globals()['run_' + args.type]()
159 |
--------------------------------------------------------------------------------
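The `cfg.save_ply` branch of `run_evaluate` above fuses rendered depth maps into a single point cloud by lifting every pixel to world space (file lines 114-124). With $K$ the $3 \times 3$ intrinsics and $T$ the $4 \times 4$ world-to-camera extrinsics returned by `read_camera_parameters`, the per-pixel back-projection is

    X_{\mathrm{ref}} = d(u, v)\, K^{-1} \begin{bmatrix} u \\ v \\ 1 \end{bmatrix},
    \qquad
    \begin{bmatrix} X_{\mathrm{world}} \\ 1 \end{bmatrix} = T^{-1} \begin{bmatrix} X_{\mathrm{ref}} \\ 1 \end{bmatrix},

which is exactly the `np.linalg.inv(ref_intrinsics)` / `np.linalg.inv(ref_extrinsics)` pair applied in the loop.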
/train_net.py:
--------------------------------------------------------------------------------
1 | from lib.config import cfg, args
2 | from lib.networks import make_network
3 | from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
4 | from lib.datasets import make_data_loader
5 | from lib.utils.net_utils import load_model, save_model, load_network, save_trained_config, load_pretrain
6 | from lib.evaluators import make_evaluator
7 | import torch.multiprocessing
8 | import torch
9 | import torch.distributed as dist
10 | import os
11 | # torch.autograd.set_detect_anomaly(True)
12 |
13 | if cfg.fix_random:
14 | torch.manual_seed(0)
15 | torch.backends.cudnn.deterministic = True
16 | torch.backends.cudnn.benchmark = False
17 |
18 |
19 | def train(cfg, network):
20 | train_loader = make_data_loader(cfg,
21 | is_train=True,
22 | is_distributed=cfg.distributed,
23 | max_iter=cfg.ep_iter)
24 | if cfg.skip_eval:
25 | val_loader = None
26 | else:
27 | val_loader = make_data_loader(cfg, is_train=False)
28 | trainer = make_trainer(cfg, network, train_loader)
29 | optimizer = make_optimizer(cfg, network)
30 | scheduler = make_lr_scheduler(cfg, optimizer)
31 | recorder = make_recorder(cfg)
32 | evaluator = make_evaluator(cfg)
33 |
34 | begin_epoch = load_model(network,
35 | optimizer,
36 | scheduler,
37 | recorder,
38 | cfg.trained_model_dir,
39 | resume=cfg.resume)
40 | if begin_epoch == 0 and cfg.pretrain != '':
41 | load_pretrain(network, cfg.pretrain)
42 |
43 |
44 | set_lr_scheduler(cfg, scheduler)
45 | psnr_best, ssim_best, lpips_best = 0, 0, 10
46 | for epoch in range(begin_epoch, cfg.train.epoch):
47 | recorder.epoch = epoch
48 | if cfg.distributed:
49 | train_loader.batch_sampler.sampler.set_epoch(epoch)
50 |
51 | train_loader.dataset.epoch = epoch
52 |
53 | trainer.train(epoch, train_loader, optimizer, recorder)
54 | scheduler.step()
55 |
56 | if (epoch + 1) % cfg.save_ep == 0 and cfg.local_rank == 0:
57 | save_model(network, optimizer, scheduler, recorder,
58 | cfg.trained_model_dir, epoch)
59 |
60 | if (epoch + 1) % cfg.save_latest_ep == 0 and cfg.local_rank == 0:
61 | save_model(network,
62 | optimizer,
63 | scheduler,
64 | recorder,
65 | cfg.trained_model_dir,
66 | epoch,
67 | last=True)
68 |
69 | if not cfg.skip_eval and (epoch + 1) % cfg.eval_ep == 0 and cfg.local_rank == 0:
70 | result = trainer.val(epoch, val_loader, evaluator, recorder)
71 | psnr = result['psnr']
72 | ssim = result['ssim']
73 | lpips = result['lpips']
74 | if psnr > psnr_best:
75 | psnr_best = psnr
76 | save_model(network,
77 | optimizer,
78 | scheduler,
79 | recorder,
80 | cfg.trained_model_dir,
81 | epoch,
82 | custom='psnr_best')
83 | if ssim > ssim_best:
84 | ssim_best = ssim
85 | save_model(network,
86 | optimizer,
87 | scheduler,
88 | recorder,
89 | cfg.trained_model_dir,
90 | epoch,
91 | custom='ssim_best')
92 | if lpips < lpips_best:
93 | lpips_best = lpips
94 | save_model(network,
95 | optimizer,
96 | scheduler,
97 | recorder,
98 | cfg.trained_model_dir,
99 | epoch,
100 | custom='lpips_best')
101 | print(f'psnr_best: {psnr_best:.2f}, ssim_best: {ssim_best:.3f}, lpips_best: {lpips_best:.3f}')
102 |
103 |
104 |
105 | return network
106 |
107 |
108 | def test(cfg, network):
109 | trainer = make_trainer(cfg, network)
110 | val_loader = make_data_loader(cfg, is_train=False)
111 | evaluator = make_evaluator(cfg)
112 | epoch = load_network(network,
113 | cfg.trained_model_dir,
114 | resume=cfg.resume,
115 | epoch=cfg.test.epoch)
116 | trainer.val(epoch, val_loader, evaluator)
117 |
118 | def synchronize():
119 | """
120 | Helper function to synchronize (barrier) among all processes when
121 | using distributed training
122 | """
123 | if not dist.is_available():
124 | return
125 | if not dist.is_initialized():
126 | return
127 | world_size = dist.get_world_size()
128 | if world_size == 1:
129 | return
130 | dist.barrier()
131 |
132 | def main():
133 | if cfg.distributed:
134 | cfg.local_rank = int(os.environ['RANK']) % torch.cuda.device_count()
135 | torch.cuda.set_device(cfg.local_rank)
136 | torch.distributed.init_process_group(backend="nccl",
137 | init_method="env://")
138 | synchronize()
139 |
140 | network = make_network(cfg)
141 | if args.test:
142 | test(cfg, network)
143 | else:
144 | train(cfg, network)
145 | if cfg.local_rank == 0:
146 | print('Success!')
147 | print('='*80)
148 | os.system('kill -9 {}'.format(os.getpid()))
149 |
150 |
151 | if __name__ == "__main__":
152 | main()
153 |
--------------------------------------------------------------------------------
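Both entry points resolve their options through `lib.config`, so a typical launch is expected to look like `python train_net.py --cfg_file configs/gefu/dtu_pretrain.yaml` for training and `python run.py --type evaluate --cfg_file configs/gefu/llff_eval.yaml` for evaluation. The flag names are defined in `lib/config/config.py` and are stated here as an assumption rather than read from these files; `run.py` only shows that `args.type` selects the matching `run_*` function, and `train_net.py` that `args.test` switches to `test()`.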