├── LICENSE ├── README.md ├── assets └── overview.png ├── configs ├── default.yaml └── gefu │ ├── dtu │ ├── scan1.yaml │ ├── scan103.yaml │ ├── scan114.yaml │ ├── scan21.yaml │ └── scan8.yaml │ ├── dtu_pretrain.yaml │ ├── llff │ ├── fern.yaml │ ├── flower.yaml │ ├── fortress.yaml │ ├── horns.yaml │ ├── leaves.yaml │ ├── orchids.yaml │ ├── room.yaml │ └── trex.yaml │ ├── llff_eval.yaml │ ├── nerf │ ├── chair.yaml │ ├── drums.yaml │ ├── ficus.yaml │ ├── hotdog.yaml │ ├── lego.yaml │ ├── materials.yaml │ ├── mic.yaml │ └── ship.yaml │ └── nerf_eval.yaml ├── data └── mvsnerf │ ├── dtu_train_all.txt │ ├── dtu_val_all.txt │ └── pairs.th ├── lib ├── __init__.py ├── config │ ├── __init__.py │ ├── config.py │ └── yacs.py ├── datasets │ ├── __init__.py │ ├── collate_batch.py │ ├── dtu │ │ └── gefu.py │ ├── gefu_utils.py │ ├── llff │ │ └── gefu.py │ ├── make_dataset.py │ ├── nerf │ │ └── gefu.py │ ├── samplers.py │ └── video_utils.py ├── evaluators │ ├── __init__.py │ ├── gefu.py │ └── make_evaluator.py ├── networks │ ├── __init__.py │ ├── gefu │ │ ├── cost_reg_net.py │ │ ├── feature_net.py │ │ ├── nerf.py │ │ ├── network.py │ │ ├── res_unet.py │ │ └── utils.py │ └── make_network.py ├── train │ ├── __init__.py │ ├── losses │ │ ├── gefu.py │ │ ├── nerf.py │ │ ├── ssim_loss.py │ │ └── vgg_perceptual_loss.py │ ├── optimizer.py │ ├── recorder.py │ ├── scheduler.py │ └── trainers │ │ ├── __init__.py │ │ ├── make_trainer.py │ │ └── trainer.py └── utils │ ├── base_utils.py │ ├── data_config.py │ ├── data_utils.py │ ├── img_utils.py │ ├── mask_utils.py │ ├── net_utils.py │ ├── optimizer │ ├── lr_scheduler.py │ └── radam.py │ ├── ply_utils.py │ ├── rend_utils.py │ └── vis_utils.py ├── requirements.txt ├── run.py └── train_net.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Liu Tianqi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Geometry-aware Reconstruction and Fusion-refined Rendering for Generalizable Neural Radiance Fields 2 | 3 | PyTorch implementation of paper "Geometry-aware Reconstruction and Fusion-refined Rendering for Generalizable Neural Radiance Fields", CVPR 2024. 
4 | 5 | > [Geometry-aware Reconstruction and Fusion-refined Rendering for Generalizable Neural Radiance Fields](https://arxiv.org/abs/2404.17528) 6 | > Tianqi Liu, Xinyi Ye, Min Shi, Zihao Huang, Zhiyu Pan, Zhan Peng, Zhiguo Cao* \ 7 | > CVPR 2024 8 | > [project page](https://gefucvpr24.github.io/) | [paper](https://arxiv.org/abs/2404.17528) | [poster](https://cvpr.thecvf.com/virtual/2024/poster/29889) | [model](https://drive.google.com/drive/folders/1pCCOLUj2fNAbp0ZXj7_tEPfF1vZvMbh-?usp=drive_link) 9 | 10 | 11 | ## Introduction 12 | Generalizable NeRF aims to synthesize novel views for unseen scenes. Common practices involve constructing variance-based cost volumes for geometry reconstruction and encoding 3D descriptors for decoding novel views. However, existing methods show limited generalization ability in challenging conditions due to inaccurate geometry, sub-optimal descriptors, and decoding strategies. We address these issues point by point. First, we find the variance-based cost volume exhibits failure patterns as the features of pixels corresponding to the same point can be inconsistent across different views due to occlusions or reflections. We introduce an Adaptive Cost Aggregation (ACA) approach to amplify the contribution of consistent pixel pairs and suppress inconsistent ones. Unlike previous methods that solely fuse 2D features into descriptors, our approach introduces a Spatial-View Aggregator (SVA) to incorporate 3D context into descriptors through spatial and inter-view interaction. When decoding the descriptors, we observe the two existing decoding strategies excel in different areas, which are complementary. A Consistency-Aware Fusion (CAF) strategy is proposed to leverage the advantages of both. We incorporate the above ACA, SVA, and CAF into a coarse-to-fine framework, termed Geometry-aware Reconstruction and Fusion-refined Rendering (GeFu). GeFu attains state-of-the-art performance across multiple datasets. 13 | 14 |
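For intuition only: the actual components live in `lib/networks/gefu/` (e.g. `cost_reg_net.py`, `utils.py`), but a minimal stand-alone sketch of the *idea* behind Adaptive Cost Aggregation — hypothetical module, names and shapes assumed, not the paper's exact formulation — could look like this: instead of a plain variance cost volume, a small learned head weights each source view's deviation from the cross-view mean so that inconsistent (occluded / reflective) views can be suppressed.

```
# Illustrative sketch only (hypothetical module, NOT the GeFu code in lib/networks/gefu/):
# replace a plain variance cost volume with a learned per-view weighting so that views
# consistent with the cross-view mean contribute more and inconsistent ones are suppressed.
import torch
import torch.nn as nn

class AdaptiveCostAggregation(nn.Module):
    def __init__(self, feat_ch=32):
        super().__init__()
        # tiny scoring head over per-view deviation volumes (assumed design)
        self.score = nn.Sequential(
            nn.Conv3d(feat_ch, 16, 1), nn.ReLU(inplace=True), nn.Conv3d(16, 1, 1))

    def forward(self, feat_vols):
        # feat_vols: (V, C, D, H, W) source-view features warped onto D depth planes
        mean = feat_vols.mean(dim=0, keepdim=True)        # cross-view mean, (1, C, D, H, W)
        dev = (feat_vols - mean) ** 2                     # per-view squared deviation
        weights = torch.softmax(self.score(dev), dim=0)   # per-view weights, (V, 1, D, H, W)
        return (weights * dev).sum(dim=0)                 # weighted cost volume, (C, D, H, W)

vols = torch.randn(3, 32, 8, 64, 80)                      # 3 source views
cost = AdaptiveCostAggregation(32)(vols)                  # -> torch.Size([32, 8, 64, 80])
```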

 15 | ![GeFu overview](assets/overview.png) 16 |

17 | 18 | ## Installation 19 | 20 | ### Clone this repository: 21 | 22 | ``` 23 | git clone https://github.com/TQTQliu/GeFu.git 24 | cd GeFu 25 | ``` 26 | 27 | ### Set up the python environment 28 | 29 | ``` 30 | conda create -n gefu python=3.8 31 | conda activate gefu 32 | pip install -r requirements.txt 33 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 34 | ``` 35 | 36 | 37 | ## Datasets 38 | 39 | 40 | #### 1. DTU 41 | 42 | **Training data**. Download [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) and [Depth raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip). Unzip and organize them as: 43 | ``` 44 | mvs_training 45 | ├── dtu 46 | ├── Cameras 47 | ├── Depths 48 | ├── Depths_raw 49 | └── Rectified 50 | ``` 51 | 52 | #### 2. NeRF Synthetic (Blender) and Real Forward-facing (LLFF) 53 | 54 | Download the [NeRF Synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) and [Real Forward-facing](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) datasets and unzip them. 55 | 56 | 57 | ## Usage 58 | ### Train generalizable model 59 | 60 | To train a generalizable model from scratch on DTU, specify ``data_root`` in ``configs/gefu/dtu_pretrain.yaml`` first and then run: 61 | ``` 62 | python train_net.py --cfg_file configs/gefu/dtu_pretrain.yaml 63 | ``` 64 | 65 | Our code also supports multi-gpu training. The released pretrained model was trained with 4 GPUs. 66 | ``` 67 | python -m torch.distributed.launch --nproc_per_node=4 train_net.py --cfg_file configs/gefu/dtu_pretrain.yaml distributed True gpus 0,1,2,3 68 | ``` 69 | 70 | 71 | 72 | ### Per-scene optimization 73 | Here we take the scan1 on the DTU as an example: 74 | ``` 75 | cd ./trained_model/gefu 76 | mkdir dtu_ft_scan1 77 | cp dtu_pretrain/latest.pth dtu_ft_scan1 78 | cd ../.. 79 | python train_net.py --cfg_file configs/gefu/dtu/scan1.yaml 80 | ``` 81 | 82 | We provide the finetuned models for each scenes [here](https://drive.google.com/drive/folders/11X_YI4BmYoRG1Q8AYnOnvPQmqQufqSMo?usp=drive_link). 83 | 84 | 85 | 86 | ### Evaluation 87 | 88 | #### Evaluate the pretrained model on DTU 89 | 90 | Download the [pretrained model](https://drive.google.com/drive/folders/1pCCOLUj2fNAbp0ZXj7_tEPfF1vZvMbh-?usp=drive_link) and put it into `trained_model/gefu/dtu_pretrain/latest.pth` 91 | 92 | Use the following command to evaluate the pretrained model on DTU: 93 | ``` 94 | python run.py --type evaluate --cfg_file configs/gefu/dtu_pretrain.yaml gefu.eval_depth True 95 | ``` 96 | The rendering images will be saved in ```result/gefu/dtu_pretrain```. Add the ```save_video True``` parameter at the end of the command to save the rendering videos. 97 | 98 | #### Evaluate the pretrained model on Real Forward-facing 99 | 100 | ``` 101 | python run.py --type evaluate --cfg_file configs/gefu/llff_eval.yaml 102 | ``` 103 | 104 | #### Evaluate the pretrained model on NeRF Synthetic 105 | ``` 106 | python run.py --type evaluate --cfg_file configs/gefu/nerf_eval.yaml 107 | ``` 108 | 109 | 110 | 111 | 112 | ## Citation 113 | If you find our work useful for your research, please cite our paper. 
114 | 115 | ``` 116 | @InProceedings{Liu_2024_CVPR, 117 | author = {Liu, Tianqi and Ye, Xinyi and Shi, Min and Huang, Zihao and Pan, Zhiyu and Peng, Zhan and Cao, Zhiguo}, 118 | title = {Geometry-aware Reconstruction and Fusion-refined Rendering for Generalizable Neural Radiance Fields}, 119 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 120 | month = {June}, 121 | year = {2024}, 122 | pages = {7654-7663} 123 | } 124 | ``` 125 | 126 | ## Relevant Works 127 | 128 | - [**PixelNeRF: Neural Radiance Fields from One or Few Images**](https://alexyu.net/pixelnerf/), CVPR 2021
129 | 130 | - [**IBRNet: Learning Multi-View Image-Based Rendering**](https://ibrnet.github.io/), CVPR 2021
131 | 132 | - [**MVSNeRF: Fast Generalizable Radiance Field Reconstruction from Multi-View Stereo**](https://apchenstu.github.io/mvsnerf/), ICCV 2021
133 | 134 | - [**Neural Rays for Occlusion-aware Image-based Rendering**](https://liuyuan-pal.github.io/NeuRay/), CVPR 2022
135 | 136 | - [**ENeRF: Efficient Neural Radiance Fields for Interactive Free-viewpoint Video**](https://zju3dv.github.io/enerf/), SIGGRAPH Asia 2022
137 | 138 | - [**Is Attention All NeRF Needs?**](https://vita-group.github.io/GNT/), ICLR 2023
139 | 140 | - [**Explicit Correspondence Matching for Generalizable Neural Radiance Fields**](https://donydchen.github.io/matchnerf/), arXiv 2023
141 | 142 | 143 | ## Acknowledgement 144 | 145 | The project is mainly based on [ENeRF](https://github.com/zju3dv/ENeRF?tab=readme-ov-file). Many thanks for their excellent contributions! When using our code, please also pay attention to the license of ENeRF. 146 | -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TQTQliu/GeFu/86443625c6ae69a531b81a2e672d3922d82f0e70/assets/overview.png -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TQTQliu/GeFu/86443625c6ae69a531b81a2e672d3922d82f0e70/configs/default.yaml -------------------------------------------------------------------------------- /configs/gefu/dtu/scan1.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/dtu_pretrain.yaml 2 | exp_name: dtu_ft_scan1 3 | gefu: 4 | test_input_views: 4 5 | train_dataset: 6 | scene: scan1 7 | test_dataset: 8 | scene: scan1 9 | train: 10 | epoch: 222 # pretrained epoch + 6 11 | save_ep: 1 12 | eval_ep: 1 13 | -------------------------------------------------------------------------------- /configs/gefu/dtu/scan103.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/dtu_pretrain.yaml 2 | exp_name: dtu_ft_scan103 3 | gefu: 4 | test_input_views: 4 5 | train_dataset: 6 | scene: scan103 7 | test_dataset: 8 | scene: scan103 9 | train: 10 | epoch: 222 # pretrained epoch + 6 11 | lr: 5e-5 12 | save_ep: 1 13 | eval_ep: 1 14 | -------------------------------------------------------------------------------- /configs/gefu/dtu/scan114.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/dtu_pretrain.yaml 2 | exp_name: dtu_ft_scan114 3 | gefu: 4 | test_input_views: 4 5 | train_dataset: 6 | scene: scan114 7 | test_dataset: 8 | scene: scan114 9 | train: 10 | epoch: 222 # pretrained epoch + 6 11 | lr: 5e-5 12 | save_ep: 1 13 | eval_ep: 1 14 | -------------------------------------------------------------------------------- /configs/gefu/dtu/scan21.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/dtu_pretrain.yaml 2 | exp_name: dtu_ft_scan21 3 | gefu: 4 | test_input_views: 4 5 | train_dataset: 6 | scene: scan21 7 | test_dataset: 8 | scene: scan21 9 | train: 10 | epoch: 222 # pretrained epoch + 6 11 | lr: 5e-6 12 | save_ep: 1 13 | eval_ep: 1 14 | -------------------------------------------------------------------------------- /configs/gefu/dtu/scan8.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/dtu_pretrain.yaml 2 | exp_name: dtu_ft_scan8 3 | gefu: 4 | test_input_views: 4 5 | train_dataset: 6 | scene: scan8 7 | test_dataset: 8 | scene: scan8 9 | train: 10 | epoch: 222 # pretrained epoch + 6 11 | save_ep: 1 12 | eval_ep: 1 13 | 14 | -------------------------------------------------------------------------------- /configs/gefu/dtu_pretrain.yaml: -------------------------------------------------------------------------------- 1 | task: gefu 2 | gpus: [0] 3 | exp_name: 'dtu_pretrain' 4 | 5 | 6 | # module 7 | train_dataset_module: lib.datasets.dtu.gefu 8 | 
test_dataset_module: lib.datasets.dtu.gefu 9 | network_module: lib.networks.gefu.network 10 | loss_module: lib.train.losses.gefu 11 | evaluator_module: lib.evaluators.gefu 12 | 13 | save_result: True 14 | eval_lpips: True 15 | save_video: False 16 | save_ply: False 17 | 18 | 19 | # task config 20 | gefu: 21 | train_input_views: [2, 3, 4] 22 | train_input_views_prob: [0.1, 0.8, 0.1] 23 | test_input_views: 3 24 | viewdir_agg: True 25 | chunk_size: 1000000 26 | white_bkgd: False 27 | eval_depth: False 28 | eval_center: False # only for llff evaluation (same as MVSNeRF: https://github.com/apchenstu/mvsnerf/blob/1fdf6487389d0872dade614b3cea61f7b099406e/renderer.ipynb) 29 | reweighting: False 30 | cas_config: 31 | num: 2 32 | depth_inv: [True, False] 33 | volume_scale: [0.125, 0.5] 34 | volume_planes: [64, 8] 35 | im_feat_scale: [0.25, 0.5] 36 | im_ibr_scale: [0.25, 1.] 37 | render_scale: [0.25, 1.0] 38 | render_im_feat_level: [0, 2] 39 | nerf_model_feat_ch: [32, 8] 40 | render_if: [True, True] 41 | num_samples: [8, 2] 42 | num_rays: [4096, 32768] 43 | num_patchs: [0, 0] 44 | train_img: [True, True] 45 | patch_size: [-1, -1] 46 | loss_weight: [0.5, 1.] 47 | 48 | 49 | 50 | train_dataset: 51 | data_root: 'mvs_training/dtu/' 52 | ann_file: 'data/mvsnerf/dtu_train_all.txt' 53 | split: 'train' 54 | batch_size: 2 55 | input_ratio: 1. 56 | 57 | test_dataset: 58 | data_root: 'mvs_training/dtu/' 59 | ann_file: 'data/mvsnerf/dtu_val_all.txt' 60 | split: 'test' 61 | batch_size: 1 62 | input_ratio: 1. 63 | 64 | train: 65 | batch_size: 1 66 | lr: 5e-4 67 | weight_decay: 0. 68 | epoch: 300 69 | scheduler: 70 | type: 'exponential' 71 | gamma: 0.5 72 | decay_epochs: 50 73 | batch_sampler: 'gefu' 74 | collator: 'gefu' 75 | sampler_meta: 76 | input_views_num: [2, 3, 4] 77 | input_views_prob: [0.1, 0.8, 0.1] 78 | num_workers: 4 79 | 80 | test: 81 | batch_size: 1 82 | collator: 'gefu' 83 | batch_sampler: 'gefu' 84 | sampler_meta: 85 | input_views_num: [3] 86 | input_views_prob: [1.] 
87 | 88 | ep_iter: 1000 89 | save_ep: 1 90 | eval_ep: 1 91 | save_latest_ep: 1 92 | log_interval: 1 93 | -------------------------------------------------------------------------------- /configs/gefu/llff/fern.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/llff_eval.yaml 2 | exp_name: llff_ft_fern 3 | 4 | gefu: 5 | test_input_views: 4 6 | train_input_views: [3, 4] 7 | train_input_views_prob: [0.4, 0.6] 8 | cas_config: 9 | render_if: [True, True] 10 | train_dataset: 11 | scene: fern 12 | test_dataset: 13 | scene: fern 14 | train: 15 | epoch: 222 # pretrained epoch + 6 16 | lr: 5e-5 17 | sampler_meta: 18 | input_views_num: [3, 4] 19 | input_views_prob: [0.4, 0.6] 20 | save_ep: 1 21 | eval_ep: 1 22 | -------------------------------------------------------------------------------- /configs/gefu/llff/flower.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/llff_eval.yaml 2 | exp_name: llff_ft_flower 3 | 4 | gefu: 5 | test_input_views: 4 6 | train_input_views: [3, 4] 7 | train_input_views_prob: [0.4, 0.6] 8 | cas_config: 9 | render_if: [True, True] 10 | train_dataset: 11 | scene: flower 12 | test_dataset: 13 | scene: flower 14 | train: 15 | epoch: 222 # pretrained epoch + 6 16 | lr: 5e-5 17 | sampler_meta: 18 | input_views_num: [3, 4] 19 | input_views_prob: [0.4, 0.6] 20 | save_ep: 1 21 | eval_ep: 1 22 | -------------------------------------------------------------------------------- /configs/gefu/llff/fortress.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/llff_eval.yaml 2 | exp_name: llff_ft_fortress 3 | 4 | gefu: 5 | test_input_views: 4 6 | train_input_views: [3, 4] 7 | train_input_views_prob: [0.4, 0.6] 8 | cas_config: 9 | render_if: [True, True] 10 | train_dataset: 11 | scene: fortress 12 | test_dataset: 13 | scene: fortress 14 | train: 15 | epoch: 222 # pretrained epoch + 6 16 | lr: 5e-6 17 | sampler_meta: 18 | input_views_num: [3, 4] 19 | input_views_prob: [0.4, 0.6] 20 | save_ep: 1 21 | eval_ep: 1 22 | -------------------------------------------------------------------------------- /configs/gefu/llff/horns.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/llff_eval.yaml 2 | exp_name: llff_ft_horns 3 | 4 | gefu: 5 | test_input_views: 4 6 | train_input_views: [3, 4] 7 | train_input_views_prob: [0.4, 0.6] 8 | cas_config: 9 | render_if: [True, True] 10 | train_dataset: 11 | scene: horns 12 | test_dataset: 13 | scene: horns 14 | train: 15 | epoch: 222 # pretrained epoch + 6 16 | lr: 5e-5 17 | sampler_meta: 18 | input_views_num: [3, 4] 19 | input_views_prob: [0.4, 0.6] 20 | save_ep: 1 21 | eval_ep: 1 22 | -------------------------------------------------------------------------------- /configs/gefu/llff/leaves.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/llff_eval.yaml 2 | exp_name: llff_ft_leaves 3 | 4 | gefu: 5 | test_input_views: 4 6 | train_input_views: [3, 4] 7 | train_input_views_prob: [0.4, 0.6] 8 | cas_config: 9 | render_if: [True, True] 10 | train_dataset: 11 | scene: leaves 12 | test_dataset: 13 | scene: leaves 14 | train: 15 | epoch: 222 # pretrained epoch + 6 16 | lr: 5e-5 17 | sampler_meta: 18 | input_views_num: [3, 4] 19 | input_views_prob: [0.4, 0.6] 20 | save_ep: 1 21 | eval_ep: 1 22 | 
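All of the scene-specific configs in this directory override only a handful of fields and point back to a parent file via `parent_cfg`; the repository resolves that chain recursively in `make_cfg` (`lib/config/config.py`, included further below). As a rough stand-alone illustration — plain dicts and PyYAML instead of the `yacs.CfgNode` machinery actually used — the merge behaves roughly like:

```
# Simplified illustration of recursive parent_cfg resolution (the real logic is
# merge_cfg/make_cfg in lib/config/config.py and uses yacs; plain dicts here).
import yaml

def deep_update(base, override):
    # child values win; nested dicts are merged key by key
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(base.get(k), dict):
            deep_update(base[k], v)
        else:
            base[k] = v

def load_cfg(cfg_file):
    with open(cfg_file) as f:
        current = yaml.safe_load(f)
    if 'parent_cfg' in current:
        base = load_cfg(current['parent_cfg'])   # resolve the parent chain first
        deep_update(base, current)               # then apply this file's overrides
        return base
    return current

# e.g. configs/gefu/llff/fern.yaml -> llff_eval.yaml -> dtu_pretrain.yaml,
# with fern.yaml's overrides (scene, lr, epoch, ...) applied last.
```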
-------------------------------------------------------------------------------- /configs/gefu/llff/orchids.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/llff_eval.yaml 2 | exp_name: llff_ft_orchids 3 | 4 | gefu: 5 | test_input_views: 4 6 | train_input_views: [3, 4] 7 | train_input_views_prob: [0.4, 0.6] 8 | cas_config: 9 | render_if: [True, True] 10 | train_dataset: 11 | scene: orchids 12 | test_dataset: 13 | scene: orchids 14 | train: 15 | epoch: 222 # pretrained epoch + 6 16 | lr: 5e-5 17 | sampler_meta: 18 | input_views_num: [3, 4] 19 | input_views_prob: [0.4, 0.6] 20 | save_ep: 1 21 | eval_ep: 1 22 | -------------------------------------------------------------------------------- /configs/gefu/llff/room.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/llff_eval.yaml 2 | exp_name: llff_ft_room 3 | 4 | gefu: 5 | test_input_views: 4 6 | train_input_views: [3, 4] 7 | train_input_views_prob: [0.4, 0.6] 8 | cas_config: 9 | render_if: [True, True] 10 | train_dataset: 11 | scene: room 12 | test_dataset: 13 | scene: room 14 | train: 15 | epoch: 222 # pretrained epoch + 6 16 | lr: 1e-3 17 | sampler_meta: 18 | input_views_num: [3, 4] 19 | input_views_prob: [0.4, 0.6] 20 | save_ep: 1 21 | eval_ep: 1 22 | -------------------------------------------------------------------------------- /configs/gefu/llff/trex.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/llff_eval.yaml 2 | exp_name: llff_ft_trex 3 | 4 | gefu: 5 | test_input_views: 4 6 | train_input_views: [3, 4] 7 | train_input_views_prob: [0.4, 0.6] 8 | cas_config: 9 | render_if: [True, True] 10 | train_dataset: 11 | scene: trex 12 | test_dataset: 13 | scene: trex 14 | train: 15 | epoch: 222 # pretrained epoch + 6 16 | lr: 5e-4 17 | sampler_meta: 18 | input_views_num: [3, 4] 19 | input_views_prob: [0.4, 0.6] 20 | save_ep: 1 21 | eval_ep: 1 22 | -------------------------------------------------------------------------------- /configs/gefu/llff_eval.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/dtu_pretrain.yaml 2 | 3 | train_dataset_module: lib.datasets.llff.gefu 4 | test_dataset_module: lib.datasets.llff.gefu 5 | 6 | gefu: 7 | eval_center: True 8 | reweighting: True 9 | cas_config: 10 | render_if: [True, True] 11 | volume_planes: [32, 8] 12 | 13 | train_dataset: 14 | data_root: 'nerf_llff_data' 15 | split: 'train' 16 | # input_h_w: [640, 960] # OOM for RTX 3090 17 | input_h_w: [512, 640] 18 | batch_size: 1 19 | input_ratio: 1. 20 | 21 | test_dataset: 22 | data_root: 'nerf_llff_data' 23 | split: 'test' 24 | batch_size: 1 25 | input_h_w: [640, 960] 26 | input_ratio: 1. 
27 | -------------------------------------------------------------------------------- /configs/gefu/nerf/chair.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/nerf_eval.yaml 2 | exp_name: nerf_ft_chair 3 | gefu: 4 | test_input_views: 4 5 | train_input_views: [3, 4] 6 | train_input_views_prob: [0.4, 0.6] 7 | cas_config: 8 | render_if: [True, True] 9 | train_dataset: 10 | scene: chair 11 | test_dataset: 12 | scene: chair 13 | train: 14 | epoch: 222 # pretrained epoch + 6 15 | lr: 3e-4 16 | save_ep: 1 17 | eval_ep: 1 18 | -------------------------------------------------------------------------------- /configs/gefu/nerf/drums.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/nerf_eval.yaml 2 | exp_name: nerf_ft_drums 3 | gefu: 4 | test_input_views: 4 5 | train_input_views: [3, 4] 6 | train_input_views_prob: [0.4, 0.6] 7 | cas_config: 8 | render_if: [True, True] 9 | train_dataset: 10 | scene: drums 11 | test_dataset: 12 | scene: drums 13 | train: 14 | epoch: 222 # pretrained epoch + 6 15 | lr: 5e-5 16 | save_ep: 1 17 | eval_ep: 1 18 | -------------------------------------------------------------------------------- /configs/gefu/nerf/ficus.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/nerf_eval.yaml 2 | exp_name: nerf_ft_ficus 3 | gefu: 4 | test_input_views: 4 5 | train_input_views: [3, 4] 6 | train_input_views_prob: [0.4, 0.6] 7 | cas_config: 8 | render_if: [True, True] 9 | train_dataset: 10 | scene: ficus 11 | test_dataset: 12 | scene: ficus 13 | train: 14 | epoch: 225 # pretrained epoch + 6 15 | lr: 8e-5 16 | save_ep: 1 17 | eval_ep: 1 18 | -------------------------------------------------------------------------------- /configs/gefu/nerf/hotdog.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/nerf_eval.yaml 2 | exp_name: nerf_ft_hotdog 3 | gefu: 4 | test_input_views: 4 5 | train_input_views: [3, 4] 6 | train_input_views_prob: [0.4, 0.6] 7 | cas_config: 8 | render_if: [True, True] 9 | train_dataset: 10 | scene: hotdog 11 | test_dataset: 12 | scene: hotdog 13 | train: 14 | epoch: 222 # pretrained epoch + 6 15 | lr: 5e-4 16 | save_ep: 1 17 | eval_ep: 1 18 | -------------------------------------------------------------------------------- /configs/gefu/nerf/lego.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/nerf_eval.yaml 2 | exp_name: nerf_ft_lego 3 | gefu: 4 | test_input_views: 4 5 | train_input_views: [3, 4] 6 | train_input_views_prob: [0.4, 0.6] 7 | cas_config: 8 | render_if: [True, True] 9 | train_dataset: 10 | scene: lego 11 | test_dataset: 12 | scene: lego 13 | train: 14 | epoch: 222 # pretrained epoch + 6 15 | lr: 5e-4 16 | save_ep: 1 17 | eval_ep: 1 18 | -------------------------------------------------------------------------------- /configs/gefu/nerf/materials.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/nerf_eval.yaml 2 | exp_name: nerf_ft_materials 3 | gefu: 4 | test_input_views: 4 5 | train_input_views: [3, 4] 6 | train_input_views_prob: [0.4, 0.6] 7 | cas_config: 8 | render_if: [True, True] 9 | train_dataset: 10 | scene: materials 11 | test_dataset: 12 | scene: materials 13 | train: 14 | epoch: 222 # pretrained epoch + 6 15 | lr: 6e-4 16 | save_ep: 1 17 | eval_ep: 1 18 | 
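The `train_input_views` / `train_input_views_prob` pairs set in the configs above (and the matching `train.sampler_meta` entries in the DTU/LLFF configs) control how many source views each training batch uses. The repository implements this in `samplers.GefuBatchSampler`, whose source is not part of this excerpt; a hypothetical minimal version — names and details assumed — could look like the sketch below, which also shows why `Dataset.__getitem__` receives an `(index, input_views_num)` tuple.

```
# Hypothetical sketch (the real class is lib/datasets/samplers.GefuBatchSampler, not shown
# here): each batch is tagged with a source-view count drawn from sampler_meta, so the
# dataset's __getitem__ receives (index, n_views) instead of a bare index.
import numpy as np

class ViewCountBatchSampler:
    def __init__(self, sampler, batch_size, drop_last=False,
                 views_num=(2, 3, 4), views_prob=(0.1, 0.8, 0.1)):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.views_num = list(views_num)
        self.views_prob = list(views_prob)

    def __iter__(self):
        batch = []
        n_views = int(np.random.choice(self.views_num, p=self.views_prob))
        for idx in self.sampler:
            batch.append((idx, n_views))          # dataset gets (index, n_views)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
                n_views = int(np.random.choice(self.views_num, p=self.views_prob))
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size
```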
-------------------------------------------------------------------------------- /configs/gefu/nerf/mic.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/nerf_eval.yaml 2 | exp_name: nerf_ft_mic 3 | gefu: 4 | test_input_views: 4 5 | train_input_views: [3, 4] 6 | train_input_views_prob: [0.4, 0.6] 7 | cas_config: 8 | render_if: [True, True] 9 | train_dataset: 10 | scene: mic 11 | test_dataset: 12 | scene: mic 13 | train: 14 | epoch: 222 # pretrained epoch + 6 15 | lr: 4e-4 16 | save_ep: 1 17 | eval_ep: 1 18 | -------------------------------------------------------------------------------- /configs/gefu/nerf/ship.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/nerf_eval.yaml 2 | exp_name: nerf_ft_ship 3 | gefu: 4 | test_input_views: 4 5 | train_input_views: [3, 4] 6 | train_input_views_prob: [0.4, 0.6] 7 | cas_config: 8 | render_if: [True, True] 9 | train_dataset: 10 | scene: ship 11 | test_dataset: 12 | scene: ship 13 | train: 14 | epoch: 222 # pretrained epoch + 6 15 | lr: 8e-4 16 | save_ep: 1 17 | eval_ep: 1 18 | -------------------------------------------------------------------------------- /configs/gefu/nerf_eval.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: configs/gefu/dtu_pretrain.yaml 2 | 3 | train_dataset_module: lib.datasets.nerf.gefu 4 | test_dataset_module: lib.datasets.nerf.gefu 5 | 6 | gefu: 7 | reweighting: True 8 | cas_config: 9 | render_if: [True, True] 10 | 11 | train_dataset: 12 | data_root: 'nerf_synthetic' 13 | split: 'train' 14 | # input_h_w: [800, 800] # OOM for RTX 3090 15 | input_h_w: [512, 640] 16 | batch_size: 1 17 | input_ratio: 1. 18 | 19 | test_dataset: 20 | data_root: 'nerf_synthetic' 21 | split: 'test' 22 | input_h_w: [800, 800] 23 | batch_size: 1 24 | input_ratio: 1. 
25 | 26 | 27 | -------------------------------------------------------------------------------- /data/mvsnerf/dtu_train_all.txt: -------------------------------------------------------------------------------- 1 | scan3 2 | scan4 3 | scan5 4 | scan6 5 | scan9 6 | scan10 7 | scan11 8 | scan12 9 | scan13 10 | scan14 11 | scan15 12 | scan16 13 | scan17 14 | scan18 15 | scan19 16 | scan20 17 | scan22 18 | scan23 19 | scan24 20 | scan28 21 | scan32 22 | scan33 23 | scan35 24 | scan36 25 | scan37 26 | scan42 27 | scan43 28 | scan44 29 | scan46 30 | scan47 31 | scan48 32 | scan49 33 | scan50 34 | scan52 35 | scan53 36 | scan59 37 | scan60 38 | scan61 39 | scan62 40 | scan64 41 | scan65 42 | scan66 43 | scan67 44 | scan68 45 | scan69 46 | scan70 47 | scan71 48 | scan72 49 | scan74 50 | scan75 51 | scan76 52 | scan77 53 | scan84 54 | scan85 55 | scan86 56 | scan87 57 | scan88 58 | scan89 59 | scan90 60 | scan91 61 | scan92 62 | scan93 63 | scan94 64 | scan95 65 | scan96 66 | scan97 67 | scan98 68 | scan99 69 | scan100 70 | scan101 71 | scan102 72 | scan104 73 | scan105 74 | scan106 75 | scan107 76 | scan108 77 | scan109 78 | scan118 79 | scan119 80 | scan120 81 | scan121 82 | scan122 83 | scan123 84 | scan124 85 | scan125 86 | scan126 87 | scan127 88 | scan128 -------------------------------------------------------------------------------- /data/mvsnerf/dtu_val_all.txt: -------------------------------------------------------------------------------- 1 | scan1 2 | scan8 3 | scan21 4 | scan30 5 | scan31 6 | scan34 7 | scan38 8 | scan40 9 | scan41 10 | scan45 11 | scan55 12 | scan63 13 | scan82 14 | scan103 15 | scan110 16 | scan114 17 | -------------------------------------------------------------------------------- /data/mvsnerf/pairs.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TQTQliu/GeFu/86443625c6ae69a531b81a2e672d3922d82f0e70/data/mvsnerf/pairs.th -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TQTQliu/GeFu/86443625c6ae69a531b81a2e672d3922d82f0e70/lib/__init__.py -------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import cfg, args 2 | -------------------------------------------------------------------------------- /lib/config/config.py: -------------------------------------------------------------------------------- 1 | from .yacs import CfgNode as CN 2 | import argparse 3 | import os 4 | import numpy as np 5 | from . import yacs 6 | 7 | 8 | cfg = CN() 9 | 10 | os.environ['workspace'] = "." 11 | cfg.workspace = os.environ['workspace'] 12 | print('Workspace: ', cfg.workspace) 13 | 14 | 15 | cfg.level = 32. 16 | cfg.resolution = 256 17 | 18 | cfg.vis_encoder = '' 19 | cfg.feat_vis_len = 8 20 | 21 | cfg.cache_data = False 22 | cfg.sample_keypoints_epoch = -1 23 | 24 | cfg.write_video = False 25 | cfg.interested_mask = False 26 | cfg.render_path = False 27 | cfg.render_emb = 0 28 | cfg.render_ixt = 0 29 | cfg.code_id = -1 30 | cfg.time_weight = 0. 
31 | cfg.render_static = True 32 | cfg.pretrain_path = '' 33 | cfg.scene = 'test' 34 | cfg.last_view = False 35 | cfg.exp_hard = False 36 | cfg.pos_encoding_t = False 37 | cfg.render_time = False 38 | cfg.render_time_skip = [0, -1, 1] 39 | cfg.start_time = [2009, 8, 1] 40 | cfg.end_time = [2013, 12, 1] 41 | cfg.discrete_3views = False 42 | cfg.fps = 24 43 | cfg.dcat = False 44 | cfg.min_y = -100000000. 45 | cfg.time_discrete = -1 46 | cfg.render_day = False 47 | cfg.render_date = [2013, 1, 1] 48 | cfg.rand_t = -1. 49 | cfg.semantic_mask = False 50 | cfg.product_combine = False 51 | cfg.unisample = False 52 | cfg.render_emb_2 = -1 53 | cfg.render_num = 30 54 | cfg.render_ext = 0 55 | cfg.time_geo = False 56 | cfg.reg_beta = False 57 | cfg.fix_beta = False 58 | cfg.hard_lap = False 59 | cfg.render_octree = False 60 | cfg.render_mask = False 61 | cfg.environment_map = False 62 | 63 | cfg.save_result = False 64 | cfg.clear_result = False 65 | cfg.save_tag = 'default' 66 | # module 67 | cfg.train_dataset_module = 'lib.datasets.dtu.neus' 68 | cfg.test_dataset_module = 'lib.datasets.dtu.neus' 69 | cfg.val_dataset_module = 'lib.datasets.dtu.neus' 70 | cfg.network_module = 'lib.neworks.neus.neus' 71 | cfg.loss_module = 'lib.train.losses.neus' 72 | cfg.evaluator_module = 'lib.evaluators.neus' 73 | 74 | # experiment name 75 | cfg.exp_name = 'gitbranch_hello' 76 | cfg.exp_name_tag = '' 77 | cfg.pretrain = '' 78 | 79 | # network 80 | cfg.distributed = False 81 | 82 | # task 83 | cfg.task = 'hello' 84 | 85 | # gpus 86 | cfg.gpus = list(range(4)) 87 | # if load the pretrained network 88 | cfg.resume = True 89 | 90 | # epoch 91 | cfg.ep_iter = -1 92 | cfg.save_ep = 1 93 | cfg.save_latest_ep = 1 94 | cfg.eval_ep = 1 95 | log_interval: 20 96 | 97 | 98 | cfg.task_arg = CN() 99 | cfg.task_arg.sample_more_on_mask = -1. 100 | cfg.task_arg.sample_on_mask = False 101 | 102 | # ----------------------------------------------------------------------------- 103 | # train 104 | # ----------------------------------------------------------------------------- 105 | cfg.train = CN() 106 | cfg.train.epoch = 10000 107 | cfg.train.num_workers = 8 108 | cfg.train.collator = 'default' 109 | cfg.train.batch_sampler = 'default' 110 | cfg.train.sampler_meta = CN({}) 111 | cfg.train.shuffle = True 112 | cfg.train.eps = 1e-8 113 | 114 | # use adam as default 115 | cfg.train.optim = 'adam' 116 | cfg.train.lr = 5e-4 117 | cfg.train.weight_decay = 0. 
118 | cfg.train.scheduler = CN({'type': 'multi_step', 'milestones': [80, 120, 200, 240], 'gamma': 0.5}) 119 | cfg.train.batch_size = 4 120 | 121 | # test 122 | cfg.test = CN() 123 | cfg.test.batch_size = 1 124 | cfg.test.collator = 'default' 125 | cfg.test.epoch = -1 126 | cfg.test.batch_sampler = 'default' 127 | cfg.test.sampler_meta = CN({}) 128 | 129 | # trained model 130 | cfg.trained_model_dir = os.path.join(os.environ['workspace'], 'trained_model') 131 | cfg.clean_tag = 'debug' 132 | 133 | # recorder 134 | cfg.record_dir = os.path.join(os.environ['workspace'], 'record') 135 | 136 | # result 137 | cfg.result_dir = os.path.join(os.environ['workspace'], 'result') 138 | 139 | # evaluation 140 | cfg.skip_eval = False 141 | 142 | cfg.fix_random = False 143 | 144 | def parse_cfg(cfg, args): 145 | if len(cfg.task) == 0: 146 | raise ValueError('task must be specified') 147 | 148 | # assign the gpus 149 | if -1 not in cfg.gpus: 150 | os.environ['CUDA_VISIBLE_DEVICES'] = ', '.join([str(gpu) for gpu in cfg.gpus]) 151 | 152 | if 'bbox' in cfg: 153 | bbox = np.array(cfg.bbox).reshape((2, 3)) 154 | center, half_size = np.mean(bbox, axis=0), (bbox[1]-bbox[0]).max().item() / 2. 155 | bbox = np.stack([center-half_size, center+half_size]) 156 | cfg.bbox = bbox.reshape(6).tolist() 157 | 158 | if len(cfg.exp_name_tag) != 0: 159 | cfg.exp_name += ('_' + cfg.exp_name_tag) 160 | cfg.exp_name = cfg.exp_name.replace('gitbranch', os.popen('git describe --all').readline().strip()[6:]) 161 | cfg.exp_name = cfg.exp_name.replace('gitcommit', os.popen('git describe --tags --always').readline().strip()) 162 | print('EXP NAME: ', cfg.exp_name) 163 | cfg.trained_model_dir = os.path.join(cfg.trained_model_dir, cfg.task, cfg.exp_name) 164 | cfg.record_dir = os.path.join(cfg.record_dir, cfg.task, cfg.exp_name) 165 | cfg.result_dir = os.path.join(cfg.result_dir, cfg.task, cfg.exp_name, cfg.save_tag) 166 | cfg.local_rank = args.local_rank 167 | modules = [key for key in cfg if '_module' in key] 168 | for module in modules: 169 | cfg[module.replace('_module', '_path')] = cfg[module].replace('.', '/') + '.py' 170 | 171 | def make_cfg(args): 172 | def merge_cfg(cfg_file, cfg): 173 | with open(cfg_file, 'r') as f: 174 | current_cfg = yacs.load_cfg(f) 175 | if 'parent_cfg' in current_cfg.keys(): 176 | cfg = merge_cfg(current_cfg.parent_cfg, cfg) 177 | cfg.merge_from_other_cfg(current_cfg) 178 | else: 179 | cfg.merge_from_other_cfg(current_cfg) 180 | print(cfg_file) 181 | return cfg 182 | cfg_ = merge_cfg(args.cfg_file, cfg) 183 | try: 184 | index = args.opts.index('other_opts') 185 | cfg_.merge_from_list(args.opts[:index]) 186 | except: 187 | cfg_.merge_from_list(args.opts) 188 | parse_cfg(cfg_, args) 189 | return cfg_ 190 | 191 | 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument("--cfg_file", default="configs/gefu/dtu_pretrain.yaml", type=str) 194 | parser.add_argument('--test', action='store_true', dest='test', default=False) 195 | parser.add_argument("--type", type=str, default="") 196 | parser.add_argument('--det', type=str, default='') 197 | parser.add_argument('--local_rank', type=int, default=0) 198 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER) 199 | args = parser.parse_args() 200 | if len(args.type) > 0: 201 | cfg.task = "run" 202 | cfg = make_cfg(args) 203 | 204 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_dataset import 
make_data_loader 2 | -------------------------------------------------------------------------------- /lib/datasets/collate_batch.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import default_collate 2 | import torch 3 | import numpy as np 4 | from lib.config import cfg 5 | 6 | _collators = {} 7 | 8 | def make_collator(cfg, is_train): 9 | collator = cfg.train.collator if is_train else cfg.test.collator 10 | if collator in _collators: 11 | return _collators[collator] 12 | else: 13 | return default_collate 14 | -------------------------------------------------------------------------------- /lib/datasets/dtu/gefu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from lib.datasets import gefu_utils 4 | from lib.config import cfg 5 | import imageio 6 | import cv2 7 | import random 8 | from lib.config import cfg 9 | from lib.utils import data_utils 10 | import torch 11 | from lib.datasets.video_utils import * 12 | 13 | 14 | if cfg.fix_random: 15 | random.seed(0) 16 | np.random.seed(0) 17 | 18 | class Dataset: 19 | def __init__(self, **kwargs): 20 | super(Dataset, self).__init__() 21 | self.data_root = os.path.join(cfg.workspace, kwargs['data_root']) 22 | self.split = kwargs['split'] 23 | if 'scene' in kwargs: 24 | self.scenes = [kwargs['scene']] 25 | else: 26 | self.scenes = [] 27 | self.build_metas(kwargs['ann_file']) 28 | self.depth_ranges = [425., 905.] 29 | 30 | def build_metas(self, ann_file): 31 | scenes = [line.strip() for line in open(ann_file).readlines()] 32 | dtu_pairs = torch.load('data/mvsnerf/pairs.th') 33 | self.scene_infos = {} 34 | self.metas = [] 35 | if len(self.scenes) != 0: 36 | scenes = self.scenes 37 | 38 | for scene in scenes: 39 | scene_info = {'ixts': [], 'exts': [], 'dpt_paths': [], 'img_paths': []} 40 | for i in range(49): 41 | cam_path = os.path.join(self.data_root, 'Cameras/train/{:08d}_cam.txt'.format(i)) 42 | ixt, ext, _ = data_utils.read_cam_file(cam_path) 43 | ext[:3, 3] = ext[:3, 3] 44 | ixt[:2] = ixt[:2] * 4 45 | dpt_path = os.path.join(self.data_root, 'Depths_raw/{}/depth_map_{:04d}.pfm'.format(scene, i)) 46 | img_path = os.path.join(self.data_root, 'Rectified/{}_train/rect_{:03d}_3_r5000.png'.format(scene, i+1)) 47 | scene_info['ixts'].append(ixt.astype(np.float32)) 48 | scene_info['exts'].append(ext.astype(np.float32)) 49 | scene_info['dpt_paths'].append(dpt_path) 50 | scene_info['img_paths'].append(img_path) 51 | 52 | if self.split == 'train' and len(self.scenes) != 1: 53 | train_ids = np.arange(49).tolist() 54 | test_ids = np.arange(49).tolist() 55 | elif self.split == 'train' and len(self.scenes) == 1: 56 | train_ids = dtu_pairs['dtu_train'] 57 | test_ids = dtu_pairs['dtu_train'] 58 | else: 59 | train_ids = dtu_pairs['dtu_train'] 60 | test_ids = dtu_pairs['dtu_val'] 61 | scene_info.update({'train_ids': train_ids, 'test_ids': test_ids}) 62 | self.scene_infos[scene] = scene_info 63 | 64 | cam_points = np.array([np.linalg.inv(scene_info['exts'][i])[:3, 3] for i in train_ids]) 65 | for tar_view in test_ids: 66 | cam_point = np.linalg.inv(scene_info['exts'][tar_view])[:3, 3] 67 | distance = np.linalg.norm(cam_points - cam_point[None], axis=-1) 68 | argsorts = distance.argsort() 69 | argsorts = argsorts[1:] if tar_view in train_ids else argsorts 70 | input_views_num = cfg.gefu.train_input_views[1] + 1 if self.split == 'train' else cfg.gefu.test_input_views 71 | src_views = [train_ids[i] for i in 
argsorts[:input_views_num]] 72 | self.metas += [(scene, tar_view, src_views)] 73 | 74 | def __getitem__(self, index_meta): 75 | index, input_views_num = index_meta 76 | scene, tar_view, src_views = self.metas[index] 77 | # src_views = [tar_view] + src_views # only used for depth evaluation under reference view 78 | if self.split == 'train': 79 | if random.random() < 0.1: 80 | src_views = src_views + [tar_view] 81 | src_views = random.sample(src_views[:input_views_num+1], input_views_num) 82 | scene_info = self.scene_infos[scene] 83 | 84 | tar_img = np.array(imageio.imread(scene_info['img_paths'][tar_view])) / 255. 85 | H, W = tar_img.shape[:2] 86 | tar_ext, tar_ixt = scene_info['exts'][tar_view], scene_info['ixts'][tar_view] 87 | if self.split != 'train': # only used for evaluation 88 | tar_dpt = data_utils.read_pfm(scene_info['dpt_paths'][tar_view])[0].astype(np.float32) 89 | tar_dpt = cv2.resize(tar_dpt, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_NEAREST) 90 | tar_dpt = tar_dpt[44:556, 80:720] 91 | tar_mask = (tar_dpt > 0.).astype(np.uint8) 92 | else: 93 | # tar_dpt = np.ones_like(tar_img) 94 | # tar_mask = np.ones_like(tar_img) 95 | tar_dpt = data_utils.read_pfm(scene_info['dpt_paths'][tar_view])[0].astype(np.float32) 96 | tar_dpt = cv2.resize(tar_dpt, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_NEAREST) 97 | tar_dpt = tar_dpt[44:556, 80:720] 98 | tar_mask = (tar_dpt > 0.).astype(np.uint8) 99 | 100 | src_inps, src_exts, src_ixts = self.read_src(scene_info, src_views) 101 | 102 | ret = {'src_inps': src_inps, 103 | 'src_exts': src_exts, 104 | 'src_ixts': src_ixts} 105 | ret.update({'tar_ext': tar_ext, 106 | 'tar_ixt': tar_ixt}) 107 | if self.split != 'train': 108 | ret.update({'tar_img': tar_img, 109 | 'tar_dpt': tar_dpt, 110 | 'tar_mask': tar_mask}) 111 | ret.update({'near_far': np.array(self.depth_ranges).astype(np.float32)}) 112 | ret.update({'meta': {'scene': scene, 'tar_view': tar_view, 'frame_id': 0}}) 113 | 114 | for i in range(cfg.gefu.cas_config.num): 115 | rays, rgb, msk = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split) 116 | s = cfg.gefu.cas_config.volume_scale[i] 117 | if self.split != 'train': # evaluation 118 | tar_dpt_i = cv2.resize(tar_dpt, None, fx=s, fy=s, interpolation=cv2.INTER_NEAREST) 119 | ret.update({f'tar_dpt_{i}': tar_dpt_i.astype(np.float32)}) 120 | ret.update({f'rays_{i}': rays, f'rgb_{i}': rgb.astype(np.float32), f'msk_{i}': msk}) 121 | ret['meta'].update({f'h_{i}': H, f'w_{i}': W}) 122 | 123 | if cfg.save_video: 124 | rendering_video_meta = [] 125 | render_path_mode = 'interpolate' 126 | poses_paths = self.get_video_rendering_path(ref_poses=src_exts, mode=render_path_mode, near_far=None, train_c2w_all=None, n_frames=60) 127 | for pose in poses_paths[0]: 128 | rendering_meta = { 129 | 'tar_ext': pose 130 | } 131 | for i in range(cfg.gefu.cas_config.num): 132 | tar_ext[:3] = pose 133 | rays, _, _ = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split) 134 | rendering_meta.update({f'rays_{i}': rays}) 135 | rendering_video_meta.append(rendering_meta) 136 | ret['rendering_video_meta'] = rendering_video_meta 137 | return ret 138 | 139 | 140 | def get_video_rendering_path(self, ref_poses, mode, near_far, train_c2w_all, n_frames=60): 141 | poses_paths = [] 142 | ref_poses = ref_poses[None] 143 | for batch_idx, cur_src_poses in enumerate(ref_poses): 144 | if mode == 'interpolate': 145 | # convert to c2ws 146 | pose_square = torch.eye(4).unsqueeze(0).repeat(cur_src_poses.shape[0], 1, 1) 147 | cur_src_poses = 
torch.from_numpy(cur_src_poses) 148 | pose_square[:, :3, :] = cur_src_poses[:,:3] 149 | cur_c2ws = pose_square.double().inverse()[:, :3, :].to(torch.float32).cpu().detach().numpy() 150 | cur_path = get_interpolate_render_path(cur_c2ws, n_frames) 151 | elif mode == 'spiral': 152 | cur_c2ws_all = train_c2w_all 153 | cur_near_far = near_far.tolist() 154 | rads_scale = 0.3 155 | cur_path = get_spiral_render_path(cur_c2ws_all, cur_near_far, rads_scale=rads_scale, N_views=n_frames) 156 | else: 157 | raise Exception(f'Unknown video rendering path mode {mode}') 158 | 159 | # convert back to extrinsics tensor 160 | cur_w2cs = torch.tensor(cur_path).inverse()[:, :3].to(torch.float32) 161 | poses_paths.append(cur_w2cs) 162 | 163 | poses_paths = torch.stack(poses_paths, dim=0) 164 | return poses_paths 165 | 166 | def read_src(self, scene_info, src_views): 167 | inps, exts, ixts = [], [], [] 168 | for src_view in src_views: 169 | inps.append((np.array(imageio.imread(scene_info['img_paths'][src_view])) / 255.) * 2. - 1.) 170 | exts.append(scene_info['exts'][src_view]) 171 | ixts.append(scene_info['ixts'][src_view]) 172 | return np.stack(inps).transpose((0, 3, 1, 2)).astype(np.float32), np.stack(exts), np.stack(ixts) 173 | 174 | def __len__(self): 175 | return len(self.metas) 176 | 177 | -------------------------------------------------------------------------------- /lib/datasets/gefu_utils.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg 2 | import cv2 3 | import numpy as np 4 | 5 | def sample_patch(num_patch, patch_size, H, W, msk_sample): 6 | half_patch_size = patch_size // 2 7 | if msk_sample.sum() > 0: 8 | num_fg_patch = num_patch 9 | non_zero = msk_sample.nonzero() 10 | permutation = np.random.permutation(msk_sample.sum())[:num_fg_patch].astype(np.int32) 11 | X_, Y_ = non_zero[1][permutation], non_zero[0][permutation] 12 | X_ = np.clip(X_, half_patch_size, W-half_patch_size) 13 | Y_ = np.clip(Y_, half_patch_size, H-half_patch_size) 14 | else: 15 | num_fg_patch = 0 16 | num_patch = num_patch - num_fg_patch 17 | X = np.random.randint(low=half_patch_size, high=W-half_patch_size, size=num_patch) 18 | Y = np.random.randint(low=half_patch_size, high=H-half_patch_size, size=num_patch) 19 | if num_fg_patch > 0: 20 | X = np.concatenate([X, X_]).astype(np.int32) 21 | Y = np.concatenate([Y, Y_]).astype(np.int32) 22 | grid = np.meshgrid(np.arange(patch_size)-half_patch_size, np.arange(patch_size)-half_patch_size) 23 | return np.concatenate([grid[0].reshape(-1) + x for x in X]), np.concatenate([grid[1].reshape(-1) + y for y in Y]) 24 | 25 | def build_rays(tar_img, tar_ext, tar_ixt, tar_msk, level, split): 26 | scale = cfg.gefu.cas_config.render_scale[level] 27 | if scale != 1.: 28 | tar_img = cv2.resize(tar_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) 29 | tar_msk = cv2.resize(tar_msk, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST) 30 | tar_ixt = tar_ixt.copy() 31 | tar_ixt[:2] *= scale 32 | H, W = tar_img.shape[:2] 33 | c2w = np.linalg.inv(tar_ext) 34 | if split == 'train' and not cfg.gefu.cas_config.train_img[level]: 35 | if cfg.gefu.sample_on_mask: # 313 36 | msk_sample = tar_msk 37 | num_fg_rays = int(min(cfg.gefu.cas_config.num_rays[level]*0.75, tar_msk.sum()*0.95)) 38 | non_zero = msk_sample.nonzero() 39 | permutation = np.random.permutation(tar_msk.sum())[:num_fg_rays].astype(np.int32) 40 | X_, Y_ = non_zero[1][permutation], non_zero[0][permutation] 41 | else: 42 | num_fg_rays = 0 43 | msk_sample = 
np.zeros_like(tar_msk) 44 | num_rays = cfg.gefu.cas_config.num_rays[level] - num_fg_rays 45 | X = np.random.randint(low=0, high=W, size=num_rays) 46 | Y = np.random.randint(low=0, high=H, size=num_rays) 47 | if num_fg_rays > 0: 48 | X = np.concatenate([X, X_]).astype(np.int32) 49 | Y = np.concatenate([Y, Y_]).astype(np.int32) 50 | if cfg.gefu.cas_config.num_patchs[level] > 0: 51 | X_, Y_ = sample_patch(cfg.gefu.cas_config.num_patchs[level], cfg.gefu.cas_config.patch_size[level], H, W, msk_sample) 52 | X = np.concatenate([X, X_]).astype(np.int32) 53 | Y = np.concatenate([Y, Y_]).astype(np.int32) 54 | num_rays = len(X) 55 | rays_o = c2w[:3, 3][None].repeat(num_rays, 0) 56 | XYZ = np.concatenate((X[:, None], Y[:, None], np.ones_like(X[:, None])), axis=-1) 57 | XYZ = XYZ @ (np.linalg.inv(tar_ixt).T @ c2w[:3, :3].T) 58 | rays = np.concatenate((rays_o, XYZ, X[..., None], Y[..., None]), axis=-1) 59 | rgb = tar_img[Y, X] 60 | msk = tar_msk[Y, X] 61 | else: 62 | rays_o = c2w[:3, 3][None, None] 63 | X, Y = np.meshgrid(np.arange(W), np.arange(H)) 64 | XYZ = np.concatenate((X[:, :, None], Y[:, :, None], np.ones_like(X[:, :, None])), axis=-1) 65 | XYZ = XYZ @ (np.linalg.inv(tar_ixt).T @ c2w[:3, :3].T) 66 | rays_o = rays_o.repeat(H, axis=0) 67 | rays_o = rays_o.repeat(W, axis=1) 68 | rays = np.concatenate((rays_o, XYZ, X[..., None], Y[..., None]), axis=-1) 69 | rgb = tar_img 70 | msk = tar_msk 71 | return rays.astype(np.float32).reshape(-1, 8), rgb.reshape(-1, 3), msk.reshape(-1) 72 | 73 | 74 | -------------------------------------------------------------------------------- /lib/datasets/llff/gefu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from lib.config import cfg 4 | import imageio 5 | import cv2 6 | import random 7 | from lib.config import cfg 8 | import torch 9 | from lib.datasets import gefu_utils 10 | from lib.datasets.video_utils import * 11 | 12 | 13 | class Dataset: 14 | def __init__(self, **kwargs): 15 | super(Dataset, self).__init__() 16 | self.data_root = os.path.join(cfg.workspace, kwargs['data_root']) 17 | self.split = kwargs['split'] 18 | self.input_h_w = kwargs['input_h_w'] 19 | if 'scene' in kwargs: 20 | self.scenes = [kwargs['scene']] 21 | else: 22 | self.scenes = [] 23 | self.build_metas() 24 | 25 | def build_metas(self): 26 | if len(self.scenes) == 0: 27 | scenes = ['fern', 'flower', 'fortress', 'horns', 'leaves', 'orchids', 'room', 'trex'] 28 | else: 29 | scenes = self.scenes 30 | self.scene_infos = {} 31 | self.metas = [] 32 | self.c2ws_all = {} 33 | pairs = torch.load('data/mvsnerf/pairs.th') 34 | for scene in scenes: 35 | 36 | pose_bounds = np.load(os.path.join(self.data_root, scene, 'poses_bounds.npy')) # c2w, -u, r, -t 37 | poses = pose_bounds[:, :15].reshape((-1, 3, 5)) 38 | c2ws = np.eye(4)[None].repeat(len(poses), 0) 39 | c2ws[:, :3, 0], c2ws[:, :3, 1], c2ws[:, :3, 2], c2ws[:, :3, 3] = poses[:, :3, 1], poses[:, :3, 0], -poses[:, :3, 2], poses[:, :3, 3] 40 | ixts = np.eye(3)[None].repeat(len(poses), 0) 41 | ixts[:, 0, 0], ixts[:, 1, 1] = poses[:, 2, 4], poses[:, 2, 4] 42 | ixts[:, 0, 2], ixts[:, 1, 2] = poses[:, 1, 4]/2., poses[:, 0, 4]/2. 
43 | ixts[:, :2] *= 0.25 44 | 45 | img_paths = sorted([item for item in os.listdir(os.path.join(self.data_root, scene, 'images_4')) if '.png' in item]) 46 | depth_ranges = pose_bounds[:, -2:] 47 | scene_info = {'ixts': ixts.astype(np.float32), 'c2ws': c2ws.astype(np.float32), 'image_names': img_paths, 'depth_ranges': depth_ranges.astype(np.float32)} 48 | scene_info['scene_name'] = scene 49 | self.scene_infos[scene] = scene_info 50 | 51 | train_ids = pairs[f'{scene}_train'] 52 | if self.split == 'train': 53 | render_ids = train_ids 54 | else: 55 | render_ids = pairs[f'{scene}_val'] 56 | 57 | c2ws = np.stack(c2ws) 58 | self.c2ws_all[scene] = c2ws 59 | c2ws = c2ws[train_ids] 60 | for i in render_ids: 61 | c2w = scene_info['c2ws'][i] 62 | distance = np.linalg.norm((c2w[:3, 3][None] - c2ws[:, :3, 3]), axis=-1) 63 | argsorts = distance.argsort() 64 | argsorts = argsorts[1:] if i in train_ids else argsorts 65 | if self.split == 'train': 66 | src_views = [train_ids[i] for i in argsorts[:cfg.gefu.train_input_views[1]+1]] 67 | else: 68 | src_views = [train_ids[i] for i in argsorts[:cfg.gefu.test_input_views]] 69 | self.metas += [(scene, i, src_views)] 70 | 71 | def get_video_rendering_path(self, ref_poses, mode, near_far, train_c2w_all, n_frames=60): 72 | poses_paths = [] 73 | ref_poses = ref_poses[None] 74 | for batch_idx, cur_src_poses in enumerate(ref_poses): 75 | if mode == 'interpolate': 76 | # convert to c2ws 77 | pose_square = torch.eye(4).unsqueeze(0).repeat(cur_src_poses.shape[0], 1, 1) 78 | cur_src_poses = torch.from_numpy(cur_src_poses) 79 | pose_square[:, :3, :] = cur_src_poses[:,:3] 80 | cur_c2ws = pose_square.double().inverse()[:, :3, :].to(torch.float32).cpu().detach().numpy() 81 | cur_path = get_interpolate_render_path(cur_c2ws, n_frames) 82 | elif mode == 'spiral': 83 | cur_c2ws_all = train_c2w_all 84 | cur_near_far = near_far.tolist() 85 | rads_scale = 1 86 | cur_path = get_spiral_render_path(cur_c2ws_all, cur_near_far, rads_scale=rads_scale, N_views=n_frames) 87 | else: 88 | raise Exception(f'Unknown video rendering path mode {mode}') 89 | 90 | # convert back to extrinsics tensor 91 | cur_w2cs = torch.tensor(cur_path).inverse()[:, :3].to(torch.float32) 92 | poses_paths.append(cur_w2cs) 93 | 94 | poses_paths = torch.stack(poses_paths, dim=0) 95 | return poses_paths 96 | 97 | def __getitem__(self, index_meta): 98 | index, input_views_num = index_meta 99 | scene, tar_view, src_views = self.metas[index] 100 | if self.split == 'train': 101 | if np.random.random() < 0.1: 102 | src_views = src_views + [tar_view] 103 | src_views = random.sample(src_views, input_views_num) 104 | scene_info = self.scene_infos[scene] 105 | tar_img, tar_mask, tar_ext, tar_ixt = self.read_tar(scene_info, tar_view) 106 | src_inps, src_exts, src_ixts = self.read_src(scene_info, src_views) 107 | 108 | ret = {'src_inps': src_inps.transpose(0, 3, 1, 2), 109 | 'src_exts': src_exts, 110 | 'src_ixts': src_ixts} 111 | ret.update({'tar_ext': tar_ext, 112 | 'tar_ixt': tar_ixt}) 113 | if self.split != 'train': 114 | ret.update({'tar_img': tar_img, 115 | 'tar_mask': tar_mask}) 116 | 117 | H, W = tar_img.shape[:2] 118 | depth_ranges = np.array(scene_info['depth_ranges']) 119 | near_far = np.array([depth_ranges[:, 0].min().item(), depth_ranges[:, 1].max().item()]).astype(np.float32) 120 | # near_far = scene_info['depth_ranges'][tar_view] 121 | ret.update({'near_far': np.array(near_far).astype(np.float32)}) 122 | ret.update({'meta': {'scene': scene, 'tar_view': tar_view, 'frame_id': 0}}) 123 | 124 | for i in 
range(cfg.gefu.cas_config.num): 125 | rays, rgb, msk = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split) 126 | ret.update({f'rays_{i}': rays, f'rgb_{i}': rgb.astype(np.float32), f'msk_{i}': msk}) 127 | s = cfg.gefu.cas_config.volume_scale[i] 128 | ret['meta'].update({f'h_{i}': int(H*s), f'w_{i}': int(W*s)}) 129 | if cfg.save_video: 130 | rendering_video_meta = [] 131 | render_path_mode = 'spiral' 132 | train_c2w_all = self.c2ws_all[scene][src_views] 133 | poses_paths = self.get_video_rendering_path(src_exts, render_path_mode, near_far, train_c2w_all, 60) 134 | for pose in poses_paths[0]: 135 | rendering_meta = { 136 | 'tar_ext': pose 137 | } 138 | for i in range(cfg.gefu.cas_config.num): 139 | tar_ext[:3] = pose 140 | rays, _, _ = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split) 141 | rendering_meta.update({f'rays_{i}': rays}) 142 | rendering_video_meta.append(rendering_meta) 143 | ret['rendering_video_meta'] = rendering_video_meta 144 | return ret 145 | 146 | def read_src(self, scene, src_views): 147 | src_ids = src_views 148 | ixts, exts, imgs = [], [], [] 149 | for idx in src_ids: 150 | img, orig_size = self.read_image(scene, idx) 151 | imgs.append(((img/255.)*2-1).astype(np.float32)) 152 | ixt, ext, _ = self.read_cam(scene, idx, orig_size) 153 | ixts.append(ixt) 154 | exts.append(ext) 155 | return np.stack(imgs), np.stack(exts), np.stack(ixts) 156 | 157 | def read_tar(self, scene, view_idx): 158 | img, orig_size = self.read_image(scene, view_idx) 159 | img = (img/255.).astype(np.float32) 160 | ixt, ext, _ = self.read_cam(scene, view_idx, orig_size) 161 | mask = np.ones_like(img[..., 0]).astype(np.uint8) 162 | return img, mask, ext, ixt 163 | 164 | def read_cam(self, scene, view_idx, orig_size): 165 | ext = scene['c2ws'][view_idx].astype(np.float32) 166 | ixt = scene['ixts'][view_idx].copy() 167 | ixt[0] *= self.input_h_w[1] / orig_size[0] 168 | ixt[1] *= self.input_h_w[0] / orig_size[1] 169 | return ixt, np.linalg.inv(ext), 1 170 | 171 | def read_image(self, scene, view_idx): 172 | image_path = os.path.join(self.data_root, scene['scene_name'], 'images_4', scene['image_names'][view_idx]) 173 | img = (np.array(imageio.imread(image_path))).astype(np.float32) 174 | orig_size = img.shape[:2][::-1] 175 | img = cv2.resize(img, self.input_h_w[::-1], interpolation=cv2.INTER_AREA) 176 | return np.array(img), orig_size 177 | 178 | def __len__(self): 179 | return len(self.metas) 180 | 181 | def get_K_from_params(params): 182 | K = np.zeros((3, 3)).astype(np.float32) 183 | K[0][0], K[0][2], K[1][2] = params[:3] 184 | K[1][1] = K[0][0] 185 | K[2][2] = 1. 186 | return K 187 | 188 | -------------------------------------------------------------------------------- /lib/datasets/make_dataset.py: -------------------------------------------------------------------------------- 1 | from . 
import samplers 2 | import torch 3 | import torch.utils.data 4 | import imp 5 | import os 6 | from .collate_batch import make_collator 7 | import numpy as np 8 | import time 9 | from lib.config.config import cfg 10 | from torch.utils.data import DataLoader, ConcatDataset 11 | import cv2 12 | cv2.setNumThreads(1) 13 | 14 | 15 | # torch.multiprocessing.set_sharing_strategy('file_system') 16 | 17 | def _dataset_factory(is_train, is_val): 18 | if is_val: 19 | module = cfg.val_dataset_module 20 | path = cfg.val_dataset_path 21 | elif is_train: 22 | module = cfg.train_dataset_module 23 | path = cfg.train_dataset_path 24 | else: 25 | module = cfg.test_dataset_module 26 | path = cfg.test_dataset_path 27 | dataset = imp.load_source(module, path).Dataset 28 | return dataset 29 | 30 | 31 | def make_dataset(cfg, is_train=True): 32 | if is_train: 33 | args = cfg.train_dataset 34 | module = cfg.train_dataset_module 35 | path = cfg.train_dataset_path 36 | else: 37 | args = cfg.test_dataset 38 | module = cfg.test_dataset_module 39 | path = cfg.test_dataset_path 40 | dataset = imp.load_source(module, path).Dataset 41 | dataset = dataset(**args) 42 | return dataset 43 | 44 | 45 | def make_data_sampler(dataset, shuffle, is_distributed): 46 | if is_distributed: 47 | return samplers.DistributedSampler(dataset, shuffle=shuffle) 48 | if shuffle: 49 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 50 | else: 51 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 52 | return sampler 53 | 54 | 55 | def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter, 56 | is_train): 57 | if is_train: 58 | batch_sampler = cfg.train.batch_sampler 59 | sampler_meta = cfg.train.sampler_meta 60 | else: 61 | batch_sampler = cfg.test.batch_sampler 62 | sampler_meta = cfg.test.sampler_meta 63 | if batch_sampler == 'default': 64 | batch_sampler = torch.utils.data.sampler.BatchSampler( 65 | sampler, batch_size, drop_last) 66 | elif batch_sampler == 'image_size': 67 | batch_sampler = samplers.ImageSizeBatchSampler(sampler, batch_size, 68 | drop_last, sampler_meta) 69 | elif batch_sampler == 'gefu': 70 | batch_sampler = samplers.GefuBatchSampler(sampler, batch_size, drop_last, sampler_meta) 71 | if max_iter != -1: 72 | batch_sampler = samplers.IterationBasedBatchSampler( 73 | batch_sampler, max_iter) 74 | return batch_sampler 75 | 76 | 77 | def worker_init_fn(worker_id): 78 | np.random.seed(worker_id + (int(round(time.time() * 1000) % (2**16)))) 79 | 80 | 81 | def make_data_loader(cfg, is_train=True, is_distributed=False, max_iter=-1): 82 | if is_train: 83 | batch_size = cfg.train.batch_size 84 | # shuffle = True 85 | shuffle = cfg.train.shuffle 86 | drop_last = False 87 | else: 88 | batch_size = cfg.test.batch_size 89 | shuffle = True if is_distributed else False 90 | drop_last = False 91 | 92 | dataset = make_dataset(cfg, is_train) 93 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 94 | batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size, 95 | drop_last, max_iter, is_train) 96 | num_workers = cfg.train.num_workers 97 | collator = make_collator(cfg, is_train) 98 | data_loader = DataLoader(dataset, 99 | batch_sampler=batch_sampler, 100 | num_workers=num_workers, 101 | collate_fn=collator, 102 | worker_init_fn=worker_init_fn) 103 | 104 | return data_loader 105 | -------------------------------------------------------------------------------- /lib/datasets/nerf/gefu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 
2 | import os 3 | from lib.config import cfg 4 | import imageio 5 | import cv2 6 | import random 7 | from lib.config import cfg 8 | import torch 9 | import json 10 | from lib.datasets import gefu_utils 11 | from lib.datasets.video_utils import * 12 | 13 | class Dataset: 14 | def __init__(self, **kwargs): 15 | super(Dataset, self).__init__() 16 | self.data_root = os.path.join(cfg.workspace, kwargs['data_root']) 17 | self.split = kwargs['split'] 18 | self.input_h_w = kwargs['input_h_w'] 19 | if 'scene' in kwargs: 20 | self.scenes = [kwargs['scene']] 21 | else: 22 | self.scenes = [] 23 | self.build_metas() 24 | 25 | def build_metas(self): 26 | if len(self.scenes) == 0: 27 | scenes = ['chair', 'drums', 'ficus', 'hotdog', 'lego', 'materials', 'mic', 'ship'] 28 | else: 29 | scenes = self.scenes 30 | self.scene_infos = {} 31 | self.metas = [] 32 | pairs = torch.load('data/mvsnerf/pairs.th') 33 | for scene in scenes: 34 | json_info = json.load(open(os.path.join(self.data_root, scene,'transforms_train.json'))) 35 | b2c = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 36 | scene_info = {'ixts': [], 'exts': [], 'img_paths': []} 37 | for idx in range(len(json_info['frames'])): 38 | c2w = np.array(json_info['frames'][idx]['transform_matrix']) 39 | c2w = c2w @ b2c 40 | ext = np.linalg.inv(c2w) 41 | ixt = np.eye(3) 42 | ixt[0][2], ixt[1][2] = 400., 400. 43 | focal = .5 * 800 / np.tan(.5 * json_info['camera_angle_x']) 44 | ixt[0][0], ixt[1][1] = focal, focal 45 | scene_info['ixts'].append(ixt.astype(np.float32)) 46 | scene_info['exts'].append(ext.astype(np.float32)) 47 | img_path = os.path.join(self.data_root, scene, 'train/r_{}.png'.format(idx)) 48 | scene_info['img_paths'].append(img_path) 49 | self.scene_infos[scene] = scene_info 50 | train_ids, render_ids = pairs[f'{scene}_train'], pairs[f'{scene}_val'] 51 | if self.split == 'train': 52 | render_ids = train_ids 53 | c2ws = np.stack([np.linalg.inv(scene_info['exts'][idx]) for idx in train_ids]) 54 | for idx in render_ids: 55 | c2w = np.linalg.inv(scene_info['exts'][idx]) 56 | distance = np.linalg.norm((c2w[:3, 3][None] - c2ws[:, :3, 3]), axis=-1) 57 | 58 | argsorts = distance.argsort() 59 | argsorts = argsorts[1:] if idx in train_ids else argsorts 60 | 61 | input_views_num = cfg.gefu.train_input_views[1] + 1 if self.split == 'train' else cfg.gefu.test_input_views 62 | src_views = [train_ids[i] for i in argsorts[:input_views_num]] 63 | self.metas += [(scene, idx, src_views)] 64 | 65 | def get_video_rendering_path(self, ref_poses, mode, near_far, train_c2w_all, n_frames=60): 66 | poses_paths = [] 67 | ref_poses = ref_poses[None] 68 | for batch_idx, cur_src_poses in enumerate(ref_poses): 69 | if mode == 'interpolate': 70 | # convert to c2ws 71 | pose_square = torch.eye(4).unsqueeze(0).repeat(cur_src_poses.shape[0], 1, 1) 72 | cur_src_poses = torch.from_numpy(cur_src_poses) 73 | pose_square[:, :3, :] = cur_src_poses[:,:3] 74 | cur_c2ws = pose_square.double().inverse()[:, :3, :].to(torch.float32).cpu().detach().numpy() 75 | cur_path = get_interpolate_render_path(cur_c2ws, n_frames) 76 | elif mode == 'spiral': 77 | cur_c2ws_all = train_c2w_all 78 | cur_near_far = near_far.tolist() 79 | rads_scale = 0.3 80 | cur_path = get_spiral_render_path(cur_c2ws_all, cur_near_far, rads_scale=rads_scale, N_views=n_frames) 81 | else: 82 | raise Exception(f'Unknown video rendering path mode {mode}') 83 | 84 | # convert back to extrinsics tensor 85 | cur_w2cs = torch.tensor(cur_path).inverse()[:, :3].to(torch.float32) 86 | poses_paths.append(cur_w2cs) 
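        # Each cur_w2cs collected above is an (n_frames, 3, 4) world-to-camera matrix
        # sequence; the stack below yields (B, n_frames, 3, 4), with B = 1 here because
        # ref_poses was expanded with [None] before the loop.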
87 | 88 | poses_paths = torch.stack(poses_paths, dim=0) 89 | return poses_paths 90 | 91 | 92 | def __getitem__(self, index_meta): 93 | index, input_views_num = index_meta 94 | scene, tar_view, src_views = self.metas[index] 95 | if self.split == 'train': 96 | if np.random.random() < 0.1: 97 | src_views = src_views + [tar_view] 98 | src_views = random.sample(src_views, input_views_num) 99 | scene_info = self.scene_infos[scene] 100 | scene_info['scene_name'] = scene 101 | tar_img, tar_ext, tar_ixt = self.read_tar(scene_info, tar_view) 102 | src_inps, src_exts, src_ixts = self.read_src(scene_info, src_views) 103 | 104 | ret = {'src_inps': src_inps.transpose(0, 3, 1, 2), 105 | 'src_exts': src_exts, 106 | 'src_ixts': src_ixts} 107 | tar_mask = np.ones_like(tar_img[..., 0]).astype(np.uint8) 108 | H, W = tar_img.shape[:2] 109 | ret.update({'tar_ext': tar_ext, 110 | 'tar_ixt': tar_ixt}) 111 | if self.split != 'train': 112 | ret.update({'tar_img': tar_img, 113 | 'tar_mask': tar_mask}) 114 | near_far = np.array([2.5, 5.5]).astype(np.float32) 115 | ret.update({'near_far': np.array(near_far).astype(np.float32)}) 116 | ret.update({'meta': {'scene': scene, 'tar_view': tar_view, 'frame_id': 0}}) 117 | 118 | for i in range(cfg.gefu.cas_config.num): 119 | rays, rgb, msk = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split) 120 | ret.update({f'rays_{i}': rays, f'rgb_{i}': rgb.astype(np.float32), f'msk_{i}': msk}) 121 | s = cfg.gefu.cas_config.volume_scale[i] 122 | ret['meta'].update({f'h_{i}': int(H*s), f'w_{i}': int(W*s)}) 123 | if cfg.save_video: 124 | rendering_video_meta = [] 125 | render_path_mode = 'interpolate' 126 | poses_paths = self.get_video_rendering_path(ref_poses=src_exts, mode=render_path_mode, near_far=None, train_c2w_all=None, n_frames=60) 127 | for pose in poses_paths[0]: 128 | rendering_meta = { 129 | 'tar_ext': pose 130 | } 131 | for i in range(cfg.gefu.cas_config.num): 132 | tar_ext[:3] = pose 133 | rays, _, _ = gefu_utils.build_rays(tar_img, tar_ext, tar_ixt, tar_mask, i, self.split) 134 | rendering_meta.update({f'rays_{i}': rays}) 135 | rendering_video_meta.append(rendering_meta) 136 | ret['rendering_video_meta'] = rendering_video_meta 137 | return ret 138 | 139 | def read_src(self, scene, src_views): 140 | src_ids = src_views 141 | ixts, exts, imgs = [], [], [] 142 | for idx in src_ids: 143 | img, orig_size = self.read_image(scene, idx) 144 | imgs.append((img*2-1).astype(np.float32)) 145 | ixt, ext = self.read_cam(scene, idx, orig_size) 146 | ixts.append(ixt) 147 | exts.append(ext) 148 | return np.stack(imgs), np.stack(exts), np.stack(ixts) 149 | 150 | def read_tar(self, scene, view_idx): 151 | img, orig_size = self.read_image(scene, view_idx) 152 | ixt, ext = self.read_cam(scene, view_idx, orig_size) 153 | return img, ext, ixt 154 | 155 | def read_cam(self, scene, view_idx, orig_size): 156 | ext = scene['exts'][view_idx] 157 | ixt = scene['ixts'][view_idx].copy() 158 | ixt[0] *= self.input_h_w[1] / orig_size[0] 159 | ixt[1] *= self.input_h_w[0] / orig_size[1] 160 | return ixt, ext 161 | 162 | def read_image(self, scene, view_idx): 163 | img_path = scene['img_paths'][view_idx] 164 | img = (np.array(imageio.imread(img_path)) / 255.).astype(np.float32) 165 | img = (img[..., :3] * img[..., -1:] + (1 - img[..., -1:])).astype(np.float32) 166 | orig_size = img.shape[:2][::-1] 167 | img = cv2.resize(img, self.input_h_w[::-1], interpolation=cv2.INTER_AREA) 168 | return img, orig_size 169 | 170 | def __len__(self): 171 | return len(self.metas) 172 | 173 | def 
get_K_from_params(params): 174 | K = np.zeros((3, 3)).astype(np.float32) 175 | K[0][0], K[0][2], K[1][2] = params[:3] 176 | K[1][1] = K[0][0] 177 | K[2][2] = 1. 178 | return K 179 | 180 | -------------------------------------------------------------------------------- /lib/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from torch.utils.data.sampler import BatchSampler 3 | import numpy as np 4 | import torch 5 | import math 6 | import torch.distributed as dist 7 | from lib.config import cfg 8 | 9 | class GefuBatchSampler(Sampler): 10 | def __init__(self, sampler, batch_size, drop_last, sampler_meta): 11 | self.sampler = sampler 12 | self.batch_size = batch_size 13 | self.drop_last = drop_last 14 | self.input_views = sampler_meta.input_views_num 15 | self.views_prob = sampler_meta.input_views_prob 16 | if cfg.fix_random: 17 | random.seed(0) 18 | 19 | def __iter__(self): 20 | batch = [] 21 | input_views_num = np.random.choice(self.input_views, 1, p=self.views_prob) 22 | for idx in self.sampler: 23 | batch.append((idx, input_views_num.item())) 24 | if len(batch) == self.batch_size: 25 | input_views_num = np.random.choice(self.input_views, 1, p=self.views_prob) 26 | yield batch 27 | batch = [] 28 | if len(batch) > 0 and not self.drop_last: 29 | yield batch 30 | 31 | def __len__(self): 32 | if self.drop_last: 33 | return len(self.sampler) // self.batch_size 34 | else: 35 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 36 | 37 | 38 | class ImageSizeBatchSampler(Sampler): 39 | def __init__(self, sampler, batch_size, drop_last, sampler_meta): 40 | self.sampler = sampler 41 | self.batch_size = batch_size 42 | self.drop_last = drop_last 43 | self.strategy = sampler_meta.strategy 44 | self.hmin, self.wmin = sampler_meta.min_hw 45 | self.hmax, self.wmax = sampler_meta.max_hw 46 | self.divisor = 32 47 | if cfg.fix_random: 48 | np.random.seed(0) 49 | 50 | def generate_height_width(self): 51 | if self.strategy == 'origin': 52 | return -1, -1 53 | h = np.random.randint(self.hmin, self.hmax + 1) 54 | w = np.random.randint(self.wmin, self.wmax + 1) 55 | h = (h | (self.divisor - 1)) + 1 56 | w = (w | (self.divisor - 1)) + 1 57 | return h, w 58 | 59 | def __iter__(self): 60 | batch = [] 61 | h, w = self.generate_height_width() 62 | for idx in self.sampler: 63 | batch.append((idx, h, w)) 64 | if len(batch) == self.batch_size: 65 | h, w = self.generate_height_width() 66 | yield batch 67 | batch = [] 68 | if len(batch) > 0 and not self.drop_last: 69 | yield batch 70 | 71 | def __len__(self): 72 | if self.drop_last: 73 | return len(self.sampler) // self.batch_size 74 | else: 75 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 76 | 77 | 78 | class IterationBasedBatchSampler(BatchSampler): 79 | """ 80 | Wraps a BatchSampler, resampling from it until 81 | a specified number of iterations have been sampled 82 | """ 83 | 84 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 85 | self.batch_sampler = batch_sampler 86 | self.sampler = self.batch_sampler.sampler 87 | self.num_iterations = num_iterations 88 | self.start_iter = start_iter 89 | 90 | def __iter__(self): 91 | iteration = self.start_iter 92 | while iteration <= self.num_iterations: 93 | for batch in self.batch_sampler: 94 | iteration += 1 95 | if iteration > self.num_iterations: 96 | break 97 | yield batch 98 | 99 | def __len__(self): 100 | return self.num_iterations 101 | 102 | 103 | class 
DistributedSampler(Sampler): 104 | """Sampler that restricts data loading to a subset of the dataset. 105 | It is especially useful in conjunction with 106 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 107 | process can pass a DistributedSampler instance as a DataLoader sampler, 108 | and load a subset of the original dataset that is exclusive to it. 109 | .. note:: 110 | Dataset is assumed to be of constant size. 111 | Arguments: 112 | dataset: Dataset used for sampling. 113 | num_replicas (optional): Number of processes participating in 114 | distributed training. 115 | rank (optional): Rank of the current process within num_replicas. 116 | """ 117 | 118 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 119 | if num_replicas is None: 120 | if not dist.is_available(): 121 | raise RuntimeError("Requires distributed package to be available") 122 | num_replicas = dist.get_world_size() 123 | if rank is None: 124 | if not dist.is_available(): 125 | raise RuntimeError("Requires distributed package to be available") 126 | rank = dist.get_rank() 127 | self.dataset = dataset 128 | self.num_replicas = num_replicas 129 | self.rank = rank 130 | self.epoch = 0 131 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 132 | self.total_size = self.num_samples * self.num_replicas 133 | self.shuffle = shuffle 134 | 135 | def __iter__(self): 136 | if self.shuffle: 137 | # deterministically shuffle based on epoch 138 | g = torch.Generator() 139 | g.manual_seed(self.epoch) 140 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 141 | else: 142 | indices = torch.arange(len(self.dataset)).tolist() 143 | 144 | # add extra samples to make it evenly divisible 145 | indices += indices[: (self.total_size - len(indices))] 146 | assert len(indices) == self.total_size 147 | 148 | # subsample 149 | offset = self.num_samples * self.rank 150 | indices = indices[offset:offset+self.num_samples] 151 | assert len(indices) == self.num_samples 152 | 153 | return iter(indices) 154 | 155 | def __len__(self): 156 | return self.num_samples 157 | 158 | def set_epoch(self, epoch): 159 | self.epoch = epoch 160 | -------------------------------------------------------------------------------- /lib/datasets/video_utils.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg 2 | import cv2 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation 5 | 6 | 7 | def get_interpolate_render_path(c2ws, N_views=30): 8 | N = len(c2ws) 9 | rotvec, positions = [], [] 10 | rotvec_inteplat, positions_inteplat = [], [] 11 | weight = np.linspace(1.0, .0, N_views//3, endpoint=False).reshape(-1, 1) 12 | for i in range(N): 13 | r = Rotation.from_matrix(c2ws[i, :3, :3]) 14 | euler_ange = r.as_euler('xyz', degrees=True).reshape(1, 3) 15 | if i: 16 | mask = np.abs(euler_ange - rotvec[0]) > 180 17 | euler_ange[mask] += 360.0 18 | rotvec.append(euler_ange) 19 | positions.append(c2ws[i, :3, 3:].reshape(1, 3)) 20 | 21 | if i: 22 | rotvec_inteplat.append(weight * rotvec[i - 1] + (1.0 - weight) * rotvec[i]) 23 | positions_inteplat.append(weight * positions[i - 1] + (1.0 - weight) * positions[i]) 24 | 25 | rotvec_inteplat.append(weight * rotvec[-1] + (1.0 - weight) * rotvec[0]) 26 | positions_inteplat.append(weight * positions[-1] + (1.0 - weight) * positions[0]) 27 | 28 | c2ws_render = [] 29 | angles_inteplat, positions_inteplat = np.concatenate(rotvec_inteplat), np.concatenate(positions_inteplat) 
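    # Rebuild 4x4 c2w matrices from the interpolated Euler angles (degrees, 'xyz' order)
    # and interpolated camera centers produced above.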
30 | for rotvec, position in zip(angles_inteplat, positions_inteplat): 31 | c2w = np.eye(4) 32 | c2w[:3, :3] = Rotation.from_euler('xyz', rotvec, degrees=True).as_matrix() 33 | c2w[:3, 3:] = position.reshape(3, 1) 34 | c2ws_render.append(c2w.copy()) 35 | c2ws_render = np.stack(c2ws_render) 36 | return c2ws_render 37 | 38 | 39 | 40 | 41 | 42 | def normalize(v): 43 | """Normalize a vector.""" 44 | return v / np.linalg.norm(v) 45 | 46 | 47 | def average_poses(poses): 48 | """ 49 | Calculate the average pose, which is then used to center all poses 50 | using @center_poses. Its computation is as follows: 51 | 1. Compute the center: the average of pose centers. 52 | 2. Compute the z axis: the normalized average z axis. 53 | 3. Compute axis y': the average y axis. 54 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 55 | 5. Compute the y axis: z cross product x. 56 | 57 | Note that at step 3, we cannot directly use y' as y axis since it's 58 | not necessarily orthogonal to z axis. We need to pass from x to y. 59 | Inputs: 60 | poses: (N_images, 3, 4) 61 | Outputs: 62 | pose_avg: (3, 4) the average pose 63 | """ 64 | # 1. Compute the center 65 | center = poses[..., 3].mean(0) # (3) 66 | # 2. Compute the z axis 67 | z = normalize(poses[..., 2].mean(0)) # (3) 68 | # 3. Compute axis y' (no need to normalize as it's not the final output) 69 | y_ = poses[..., 1].mean(0) # (3) 70 | # 4. Compute the x axis 71 | x = normalize(np.cross(y_, z)) # (3) 72 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) 73 | y = np.cross(z, x) # (3) 74 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 75 | return pose_avg 76 | 77 | def center_poses(poses, blender2opencv): 78 | """ 79 | Center the poses so that we can use NDC. 80 | See https://github.com/bmild/nerf/issues/34 81 | Inputs: 82 | poses: (N_images, 3, 4) 83 | Outputs: 84 | poses_centered: (N_images, 3, 4) the centered poses 85 | pose_avg: (3, 4) the average pose 86 | """ 87 | pose_avg = average_poses(poses) # (3, 4) 88 | pose_avg_homo = np.eye(4) 89 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 90 | # by simply adding 0, 0, 0, 1 as the last row 91 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 92 | poses_homo = np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 93 | 94 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 95 | poses_centered = poses_centered @ blender2opencv 96 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 97 | 98 | return poses_centered 99 | 100 | # --------------------- Start: Render scriptes for LLFF dataset video generation -------------------- 101 | def get_spiral_render_path(c2ws_all, near_far, rads_scale=0.5, N_views=120): 102 | # center pose 103 | c2w = poses_avg(c2ws_all) 104 | 105 | # Get average pose 106 | up = normalize(c2ws_all[:, :3, 1].sum(0)) 107 | 108 | # Find a reasonable "focus depth" for this dataset 109 | close_depth, inf_depth = near_far 110 | # print(near_far) 111 | dt = .75 112 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) 113 | focal = mean_dz 114 | # print(focal) 115 | # Get radii for spiral path 116 | shrink_factor = .8 117 | zdelta = close_depth * .2 118 | tt = c2ws_all[:, :3, 3] - c2w[:3, 3][None] 119 | rads = np.percentile(np.abs(tt), 70, 0)*rads_scale 120 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views) 121 | return np.stack(render_poses) 122 | 123 | def poses_avg(poses): 
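    # Average pose used as the spiral-path center: mean camera center, normalized sum of
    # the z (viewing) axes, and summed up (y) vectors, assembled into a c2w via viewmatrix().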
124 | center = poses[:, :3, 3].mean(0) 125 | vec2 = normalize(poses[:, :3, 2].sum(0)) 126 | up = poses[:, :3, 1].sum(0) 127 | c2w = viewmatrix(vec2, up, center) 128 | 129 | return c2w 130 | 131 | 132 | def normalize(x): 133 | return x / np.linalg.norm(x, axis=-1, keepdims=True) 134 | 135 | 136 | def viewmatrix(z, up, pos): 137 | vec2 = normalize(z) 138 | vec1_avg = up 139 | vec0 = normalize(np.cross(vec1_avg, vec2)) 140 | vec1 = normalize(np.cross(vec2, vec0)) 141 | m = np.eye(4) 142 | m[:3] = np.stack([vec0, vec1, vec2, pos], 1) 143 | return m 144 | 145 | 146 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120): 147 | render_poses = [] 148 | rads = np.array(list(rads) + [1.]) 149 | 150 | for theta in np.linspace(0., 2. * np.pi * N_rots, N+1)[:-1]: 151 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 152 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) 153 | render_poses.append(viewmatrix(z, up, c)) 154 | return render_poses 155 | # --------------------- End: Render scriptes for LLFF dataset video generation -------------------- -------------------------------------------------------------------------------- /lib/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_evaluator import make_evaluator 2 | -------------------------------------------------------------------------------- /lib/evaluators/gefu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.config import cfg 3 | import os 4 | import sys 5 | import imageio 6 | from lib.utils import img_utils 7 | from skimage.metrics import structural_similarity as ssim 8 | from skimage.metrics import peak_signal_noise_ratio as psnr 9 | import torch.nn.functional as F 10 | import torch 11 | import lpips 12 | import imageio 13 | from lib.utils import img_utils 14 | import cv2 15 | import torchvision.transforms as T 16 | from PIL import Image 17 | import matplotlib.pyplot as plt 18 | 19 | 20 | def write_cam(file, ixt, ext): 21 | f = open(file, "w") 22 | f.write('extrinsic\n') 23 | for i in range(0, 4): 24 | for j in range(0, 4): 25 | f.write(str(ext[i][j]) + ' ') 26 | f.write('\n') 27 | f.write('\n') 28 | 29 | f.write('intrinsic\n') 30 | for i in range(0, 3): 31 | for j in range(0, 3): 32 | f.write(str(ixt[i][j]) + ' ') 33 | f.write('\n') 34 | 35 | f.close() 36 | 37 | 38 | def save_pfm(filename, image, scale=1): 39 | file = open(filename, "wb") 40 | color = None 41 | 42 | image = np.flipud(image) 43 | 44 | if image.dtype.name != 'float32': 45 | raise Exception('Image dtype must be float32.') 46 | 47 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 48 | color = True 49 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 50 | color = False 51 | else: 52 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 53 | 54 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) 55 | file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8')) 56 | 57 | endian = image.dtype.byteorder 58 | 59 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 60 | scale = -scale 61 | 62 | file.write(('%f\n' % scale).encode('utf-8')) 63 | 64 | image.tofile(file) 65 | file.close() 66 | 67 | 68 | def unpreprocess(data, shape=(1, 1, 3, 1, 1), render_scale=1.): 69 | device = data.device 70 | # mean = 
torch.tensor([-0.485/0.229, -0.456/0.224, -0.406 / 0.225]).view(*shape).to(device) 71 | # std = torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225]).view(*shape).to(device) 72 | img = data * 0.5 + 0.5 73 | B, S, C, H, W = img.shape 74 | img = F.interpolate(img.reshape(B*S, C, H, W), scale_factor=render_scale, align_corners=True, mode='bilinear', recompute_scale_factor=True).reshape(B, S, C, int(H*render_scale), int(W*render_scale)) 75 | return img 76 | 77 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET): 78 | if type(depth) is not np.ndarray: 79 | depth = depth.cpu().numpy() 80 | 81 | x = np.nan_to_num(depth) # change nan to 0 82 | if minmax is None: 83 | mi = np.min(x[x > 0]) # get minimum positive depth (ignore background) 84 | ma = np.max(x) 85 | else: 86 | mi, ma = minmax 87 | 88 | x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 89 | x = (255 * x).astype(np.uint8) 90 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 91 | x_ = T.ToTensor()(x_) # (3, H, W) 92 | return x_, [mi, ma] 93 | 94 | 95 | 96 | class Evaluator: 97 | 98 | def __init__(self,): 99 | self.psnrs = [] 100 | self.ssims = [] 101 | self.lpips = [] 102 | self.scene_psnrs = {} 103 | self.scene_ssims = {} 104 | self.scene_lpips = {} 105 | self.loss_fn_vgg = lpips.LPIPS(net='vgg') 106 | self.loss_fn_vgg.cuda() 107 | if cfg.gefu.eval_depth: 108 | # Following the setup of MVSNeRF 109 | self.eval_depth_scenes = ['scan1', 'scan8', 'scan21', 'scan103', 'scan110'] 110 | self.abs = [] 111 | self.acc_2 = [] 112 | self.acc_10 = [] 113 | self.mvs_abs = [] 114 | self.mvs_acc_2 = [] 115 | self.mvs_acc_10 = [] 116 | os.system('mkdir -p ' + cfg.result_dir) 117 | 118 | def evaluate(self, output, batch): 119 | B, S, _, H, W = batch['src_inps'].shape 120 | for i in range(cfg.gefu.cas_config.num): 121 | if not cfg.gefu.cas_config.render_if[i]: 122 | continue 123 | render_scale = cfg.gefu.cas_config.render_scale[i] 124 | h, w = int(H*render_scale), int(W*render_scale) 125 | pred_rgb = output[f'rgb_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy() 126 | gt_rgb = batch[f'rgb_{i}'].reshape(B, h, w, 3).detach().cpu().numpy() 127 | masks = (batch[f'msk_{i}'].reshape(B, h, w).cpu().numpy() >= 1).astype(np.uint8) 128 | if i == cfg.gefu.cas_config.num-1: 129 | pred_rgb_b = output[f'rgb_b_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy() 130 | pred_rgb_r = output[f'rgb_r_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy() 131 | 132 | if cfg.gefu.eval_center: 133 | H_crop, W_crop = int(h*0.1), int(w*0.1) 134 | pred_rgb = pred_rgb[:, H_crop:-H_crop, W_crop:-W_crop] 135 | gt_rgb = gt_rgb[:, H_crop:-H_crop, W_crop:-W_crop] 136 | masks = masks[:, H_crop:-H_crop, W_crop:-W_crop] 137 | if i == cfg.gefu.cas_config.num-1: 138 | pred_rgb_b = pred_rgb_b[:, H_crop:-H_crop, W_crop:-W_crop] 139 | pred_rgb_r = pred_rgb_r[:, H_crop:-H_crop, W_crop:-W_crop] 140 | for b in range(B): 141 | if not batch['meta']['scene'][b]+f'_level{i}' in self.scene_psnrs: 142 | self.scene_psnrs[batch['meta']['scene'][b]+f'_level{i}'] = [] 143 | self.scene_ssims[batch['meta']['scene'][b]+f'_level{i}'] = [] 144 | self.scene_lpips[batch['meta']['scene'][b]+f'_level{i}'] = [] 145 | if cfg.save_result and i == 1: 146 | img = img_utils.horizon_concate(gt_rgb[b], pred_rgb[b]) 147 | img_path = os.path.join(cfg.result_dir, '{}_{}_{}.png'.format(batch['meta']['scene'][b], batch['meta']['tar_view'][b].item(), batch['meta']['frame_id'][b].item())) 148 | imageio.imwrite(img_path, (img*255.).astype(np.uint8)) 149 | 150 | if cfg.save_ply: 151 | dataset_name = 
cfg.train_dataset_module.split('.')[-2] 152 | ply_dir = os.path.join(cfg.result_dir, 'pointclouds', dataset_name) 153 | os.makedirs(ply_dir, exist_ok = True) 154 | scan_dir = os.path.join(ply_dir, batch['meta']['scene'][0]) 155 | os.makedirs(scan_dir, exist_ok = True) 156 | img_dir = os.path.join(scan_dir, 'images') 157 | os.makedirs(img_dir, exist_ok = True) 158 | img_path = os.path.join(img_dir, '{}_{}_{}.png'.format(batch['meta']['scene'][0], batch['meta']['tar_view'][0].item(), batch['meta']['frame_id'][0].item())) 159 | img = output[f'rgb_level{i}'].reshape(B, h, w, 3).detach().cpu().numpy() 160 | img = (img[b]*255).astype(np.uint8) 161 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 162 | cv2.imwrite(img_path, img) 163 | 164 | cam_dir = os.path.join(scan_dir, 'cam') 165 | os.makedirs(cam_dir, exist_ok = True) 166 | cam_path = os.path.join(cam_dir, '{}_{}_{}.txt'.format(batch['meta']['scene'][0], batch['meta']['tar_view'][0].item(), batch['meta']['frame_id'][0].item())) 167 | 168 | ixt = batch['tar_ixt'].detach().cpu().numpy()[0] 169 | ext = batch['tar_ext'].detach().cpu().numpy()[0] 170 | write_cam(cam_path, ixt, ext) 171 | 172 | nerf_depth = output['depth_level1'].cpu().numpy()[b].reshape((h, w)) 173 | depth = nerf_depth 174 | depth_dir = os.path.join(scan_dir, 'depth') 175 | os.makedirs(depth_dir, exist_ok = True) 176 | depth_path = os.path.join(depth_dir, '{}_{}_{}.pfm'.format(batch['meta']['scene'][0], batch['meta']['tar_view'][0].item(), batch['meta']['frame_id'][0].item())) 177 | save_pfm(depth_path, depth) 178 | 179 | depth_minmax = [ 180 | batch["near_far"].min().detach().cpu().numpy(), 181 | batch["near_far"].max().detach().cpu().numpy(), 182 | ] 183 | rendered_depth_vis, _ = visualize_depth(depth, depth_minmax) 184 | rendered_depth_vis = rendered_depth_vis.permute(1,2,0).detach().cpu().numpy() 185 | depth_vis_path = os.path.join(depth_dir, '{}_{}_{}.png'.format(batch['meta']['scene'][0], batch['meta']['tar_view'][0].item(), batch['meta']['frame_id'][0].item())) 186 | imageio.imwrite(depth_vis_path, (rendered_depth_vis*255.).astype(np.uint8)) 187 | 188 | 189 | mask = masks[b] == 1 190 | gt_rgb[b][mask==False] = 0. 191 | pred_rgb[b][mask==False] = 0. 192 | 193 | psnr_item = psnr(gt_rgb[b][mask], pred_rgb[b][mask], data_range=1.) 194 | if i == cfg.gefu.cas_config.num-1: 195 | self.psnrs.append(psnr_item) 196 | self.scene_psnrs[batch['meta']['scene'][b]+f'_level{i}'].append(psnr_item) 197 | 198 | ssim_item = ssim(gt_rgb[b], pred_rgb[b], multichannel=True) 199 | if i == cfg.gefu.cas_config.num-1: 200 | self.ssims.append(ssim_item) 201 | self.scene_ssims[batch['meta']['scene'][b]+f'_level{i}'].append(ssim_item) 202 | 203 | if cfg.eval_lpips: 204 | gt, pred = torch.Tensor(gt_rgb[b])[None].permute(0, 3, 1, 2), torch.Tensor(pred_rgb[b])[None].permute(0, 3, 1, 2) 205 | gt, pred = (gt-0.5)*2., (pred-0.5)*2. 
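                    # LPIPS expects NCHW inputs in [-1, 1], hence the (x - 0.5) * 2 rescaling above.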
206 | lpips_item = self.loss_fn_vgg(gt.cuda(), pred.cuda()).item() 207 | if i == cfg.gefu.cas_config.num-1: 208 | self.lpips.append(lpips_item) 209 | self.scene_lpips[batch['meta']['scene'][b]+f'_level{i}'].append(lpips_item) 210 | 211 | if cfg.gefu.eval_depth and (i == cfg.gefu.cas_config.num - 1) and batch['meta']['scene'][b] in self.eval_depth_scenes: 212 | nerf_depth = output['depth_level1'].cpu().numpy()[b].reshape((h, w)) 213 | mvs_depth = output['depth_mvs_level1'].cpu().numpy()[b] 214 | nerf_gt_depth = batch['tar_dpt'][b].cpu().numpy().reshape(*nerf_depth.shape) 215 | mvs_gt_depth = cv2.resize(nerf_gt_depth, mvs_depth.shape[::-1], interpolation=cv2.INTER_NEAREST) 216 | # nerf_mask = np.logical_and(nerf_gt_depth > 425., nerf_gt_depth < 905.) 217 | # mvs_mask = np.logical_and(mvs_gt_depth > 425., mvs_gt_depth < 905.) 218 | nerf_mask = nerf_gt_depth != 0. 219 | mvs_mask = mvs_gt_depth != 0. 220 | self.abs.append(np.abs(nerf_depth[nerf_mask] - nerf_gt_depth[nerf_mask]).mean()) 221 | self.acc_2.append((np.abs(nerf_depth[nerf_mask] - nerf_gt_depth[nerf_mask]) < 2).mean()) 222 | self.acc_10.append((np.abs(nerf_depth[nerf_mask] - nerf_gt_depth[nerf_mask]) < 10).mean()) 223 | self.mvs_abs.append((np.abs(mvs_depth[mvs_mask] - mvs_gt_depth[mvs_mask])).mean()) 224 | self.mvs_acc_2.append((np.abs(mvs_depth[mvs_mask] - mvs_gt_depth[mvs_mask]) < 2.).mean()) 225 | self.mvs_acc_10.append((np.abs(mvs_depth[mvs_mask] - mvs_gt_depth[mvs_mask]) < 10.).mean()) 226 | 227 | def summarize(self): 228 | ret = {} 229 | ret.update({'psnr': np.mean(self.psnrs)}) 230 | ret.update({'ssim': np.mean(self.ssims)}) 231 | if cfg.eval_lpips: 232 | ret.update({'lpips': np.mean(self.lpips)}) 233 | print('='*30) 234 | for scene in self.scene_psnrs: 235 | if cfg.eval_lpips: 236 | print(scene.ljust(16), 'psnr: {:.2f} ssim: {:.3f} lpips:{:.3f}'.format(np.mean(self.scene_psnrs[scene]), np.mean(self.scene_ssims[scene]), np.mean(self.scene_lpips[scene]))) 237 | else: 238 | print(scene.ljust(16), 'psnr: {:.2f} ssim: {:.3f} '.format(np.mean(self.scene_psnrs[scene]), np.mean(self.scene_ssims[scene]))) 239 | print('='*30) 240 | print(ret) 241 | if cfg.gefu.eval_depth: 242 | depth_ret = {} 243 | keys = ['abs', 'acc_2', 'acc_10'] 244 | for key in keys: 245 | depth_ret[key] = np.mean(getattr(self, key)) 246 | setattr(self, key, []) 247 | print(depth_ret) 248 | keys = ['mvs_abs', 'mvs_acc_2', 'mvs_acc_10'] 249 | depth_ret = {} 250 | for key in keys: 251 | depth_ret[key] = np.mean(getattr(self, key)) 252 | setattr(self, key, []) 253 | print(depth_ret) 254 | self.psnrs = [] 255 | self.ssims = [] 256 | self.lpips = [] 257 | self.scene_psnrs = {} 258 | self.scene_ssims = {} 259 | self.scene_lpips = {} 260 | if cfg.save_result: 261 | print('Save visualization results to: {}'.format(cfg.result_dir)) 262 | return ret 263 | -------------------------------------------------------------------------------- /lib/evaluators/make_evaluator.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os 3 | 4 | def _evaluator_factory(cfg): 5 | module = cfg.evaluator_module 6 | path = cfg.evaluator_path 7 | evaluator = imp.load_source(module, path).Evaluator() 8 | return evaluator 9 | 10 | 11 | def make_evaluator(cfg): 12 | if cfg.skip_eval: 13 | return None 14 | else: 15 | return _evaluator_factory(cfg) 16 | -------------------------------------------------------------------------------- /lib/networks/__init__.py: -------------------------------------------------------------------------------- 1 | 
from .make_network import make_network 2 | -------------------------------------------------------------------------------- /lib/networks/gefu/cost_reg_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import * 3 | 4 | class CostRegNet(nn.Module): 5 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d): 6 | super(CostRegNet, self).__init__() 7 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act) 8 | 9 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act) 10 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act) 11 | 12 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act) 13 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 14 | 15 | self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act) 16 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act) 17 | 18 | self.conv7 = nn.Sequential( 19 | nn.ConvTranspose3d(64, 32, 3, padding=1, output_padding=1, 20 | stride=2, bias=False), 21 | norm_act(32)) 22 | 23 | self.conv9 = nn.Sequential( 24 | nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, 25 | stride=2, bias=False), 26 | norm_act(16)) 27 | 28 | self.conv11 = nn.Sequential( 29 | nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1, 30 | stride=2, bias=False), 31 | norm_act(8)) 32 | self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False)) 33 | self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False)) 34 | 35 | def forward(self, x): 36 | conv0 = self.conv0(x) 37 | conv2 = self.conv2(self.conv1(conv0)) 38 | conv4 = self.conv4(self.conv3(conv2)) 39 | x = self.conv6(self.conv5(conv4)) 40 | x = conv4 + self.conv7(x) 41 | del conv4 42 | x = conv2 + self.conv9(x) 43 | del conv2 44 | x = conv0 + self.conv11(x) 45 | del conv0 46 | feat = self.feat_conv(x) 47 | depth = self.depth_conv(x) 48 | return feat, depth.squeeze(1) 49 | 50 | 51 | class MinCostRegNet(nn.Module): 52 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d): 53 | super(MinCostRegNet, self).__init__() 54 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act) 55 | 56 | self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act) 57 | self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act) 58 | 59 | self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act) 60 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 61 | 62 | self.conv9 = nn.Sequential( 63 | nn.ConvTranspose3d(32, 16, 3, padding=1, output_padding=1, 64 | stride=2, bias=False), 65 | norm_act(16)) 66 | 67 | self.conv11 = nn.Sequential( 68 | nn.ConvTranspose3d(16, 8, 3, padding=1, output_padding=1, 69 | stride=2, bias=False), 70 | norm_act(8)) 71 | 72 | self.depth_conv = nn.Sequential(nn.Conv3d(8, 1, 3, padding=1, bias=False)) 73 | self.feat_conv = nn.Sequential(nn.Conv3d(8, 8, 3, padding=1, bias=False)) 74 | 75 | def forward(self, x): 76 | conv0 = self.conv0(x) 77 | conv2 = self.conv2(self.conv1(conv0)) 78 | conv4 = self.conv4(self.conv3(conv2)) 79 | x = conv4 80 | x = conv2 + self.conv9(x) 81 | del conv2 82 | x = conv0 + self.conv11(x) 83 | del conv0 84 | feat = self.feat_conv(x) 85 | depth = self.depth_conv(x) 86 | return feat, depth.squeeze(1) 87 | 88 | 89 | 90 | class SigCostRegNet(nn.Module): 91 | def __init__(self, in_channels, norm_act=nn.BatchNorm3d): 92 | super(SigCostRegNet, self).__init__() 93 | self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act) 94 | 95 | self.conv1 = ConvBnReLU3D(8, 16, stride=(1,2,2), pad=(1,1,1),norm_act=norm_act) 96 | self.conv2 = ConvBnReLU3D(16, 16, 
norm_act=norm_act) 97 | 98 | self.conv3 = ConvBnReLU3D(16, 32, stride=(1,2,2), pad=(1,1,1), norm_act=norm_act) 99 | self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act) 100 | 101 | self.conv5 = ConvBnReLU3D(32, 64, stride=(1,2,2), pad=(1,1,1), norm_act=norm_act) 102 | self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act) 103 | 104 | self.conv7 = nn.Sequential( 105 | nn.ConvTranspose3d(64, 32, 3, padding=(1,1,1), output_padding=(0,1,1), 106 | stride=(1,2,2), bias=False), 107 | norm_act(32)) 108 | 109 | self.conv9 = nn.Sequential( 110 | nn.ConvTranspose3d(32, 16, 3, padding=(1,1,1), output_padding=(0,1,1), 111 | stride=(1,2,2), bias=False), 112 | norm_act(16)) 113 | 114 | self.conv11 = nn.Sequential( 115 | nn.ConvTranspose3d(16, 8, 3, padding=(1,1,1), output_padding=(0,1,1), 116 | stride=(1,2,2), bias=False), 117 | norm_act(8)) 118 | 119 | self.feat_conv = nn.Sequential(nn.Conv3d(8, in_channels, 3, padding=1, bias=False)) 120 | 121 | def forward(self, x): 122 | conv0 = self.conv0(x) 123 | conv2 = self.conv2(self.conv1(conv0)) 124 | conv4 = self.conv4(self.conv3(conv2)) 125 | x = self.conv6(self.conv5(conv4)) 126 | x = conv4 + self.conv7(x) 127 | del conv4 128 | x = conv2 + self.conv9(x) 129 | del conv2 130 | x = conv0 + self.conv11(x) 131 | del conv0 132 | feat = self.feat_conv(x) 133 | return feat -------------------------------------------------------------------------------- /lib/networks/gefu/feature_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .utils import * 3 | 4 | class FeatureNet(nn.Module): 5 | def __init__(self, norm_act=nn.BatchNorm2d): 6 | super(FeatureNet, self).__init__() 7 | self.conv0 = nn.Sequential( 8 | ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act), 9 | ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act)) 10 | self.conv1 = nn.Sequential( 11 | ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act), 12 | ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act)) 13 | self.conv2 = nn.Sequential( 14 | ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), 15 | ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act)) 16 | 17 | self.toplayer = nn.Conv2d(32, 32, 1) 18 | self.lat1 = nn.Conv2d(16, 32, 1) 19 | self.lat0 = nn.Conv2d(8, 32, 1) 20 | 21 | self.smooth1 = nn.Conv2d(32, 16, 3, padding=1) 22 | self.smooth0 = nn.Conv2d(32, 8, 3, padding=1) 23 | 24 | def _upsample_add(self, x, y): 25 | return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y 26 | 27 | def forward(self, x): 28 | conv0 = self.conv0(x) 29 | conv1 = self.conv1(conv0) 30 | conv2 = self.conv2(conv1) 31 | feat2 = self.toplayer(conv2) 32 | feat1 = self._upsample_add(feat2, self.lat1(conv1)) 33 | feat0 = self._upsample_add(feat1, self.lat0(conv0)) 34 | feat1 = self.smooth1(feat1) 35 | feat0 = self.smooth0(feat0) 36 | return feat2, feat1, feat0 37 | 38 | 39 | class CNNRender(nn.Module): 40 | def __init__(self, norm_act=nn.BatchNorm2d): 41 | super(CNNRender, self).__init__() 42 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act) 43 | self.conv1 = ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act) 44 | self.conv2 = nn.Conv2d(8, 16, 1) 45 | self.conv3 = nn.Conv2d(16, 3, 1) 46 | 47 | def _upsample_add(self, x, y): 48 | return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y 49 | 50 | def forward(self, x): 51 | conv0 = self.conv0(x) 52 | conv1 = self.conv1(conv0) 53 | conv2 = self._upsample_add(conv1, self.conv2(conv0)) 54 | conv3 = self.conv3(conv2) 55 | return torch.clamp(conv3+x, 0., 1.) 
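# Illustrative usage sketch (an assumption, not part of the original file): CNNRender
# refines an RGB image with a learned residual and clamps the result back to [0, 1].
# As written, H and W should be even so the 2x upsample matches the skip connection:
#   refiner = CNNRender()
#   refined = refiner(torch.rand(1, 3, 64, 64))  # -> (1, 3, 64, 64)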
56 | 57 | 58 | class AutoEncoder(nn.Module): 59 | def __init__(self, in_channels, hid_channels, out_channels, norm_act=nn.BatchNorm2d): 60 | super(AutoEncoder, self).__init__() 61 | self.conv0 = nn.Sequential( 62 | ConvBnReLU(in_channels,hid_channels, 3, 1, 1, norm_act=norm_act), 63 | ConvBnReLU(hid_channels, hid_channels, 3, 1, 1, norm_act=norm_act)) 64 | self.conv1 = nn.Sequential( 65 | ConvBnReLU(hid_channels, hid_channels*2, 5, 2, 2, norm_act=norm_act), 66 | ConvBnReLU(hid_channels*2, hid_channels*2, 3, 1, 1, norm_act=norm_act)) 67 | self.conv2 = nn.Sequential( 68 | ConvBnReLU(hid_channels*2, hid_channels*4, 5, 2, 2, norm_act=norm_act), 69 | ConvBnReLU(hid_channels*4, hid_channels*4, 3, 1, 1, norm_act=norm_act)) 70 | 71 | self.toplayer = nn.Conv2d(hid_channels*4, out_channels, 1) 72 | self.lat1 = nn.Conv2d(hid_channels*2, out_channels, 1) 73 | self.lat0 = nn.Conv2d(hid_channels, out_channels, 1) 74 | 75 | self.out = nn.Conv2d(out_channels, out_channels, 3, padding=1) 76 | 77 | def _upsample_add(self, x, y): 78 | return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) + y 79 | 80 | def forward(self, x): 81 | conv0 = self.conv0(x) 82 | conv1 = self.conv1(conv0) 83 | conv2 = self.conv2(conv1) 84 | feat2 = self.toplayer(conv2) 85 | feat1 = self._upsample_add(feat2, self.lat1(conv1)) 86 | feat0 = self._upsample_add(feat1, self.lat0(conv0)) 87 | out = self.out(feat0) 88 | return out 89 | -------------------------------------------------------------------------------- /lib/networks/gefu/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from lib.config import cfg 5 | from .cost_reg_net import SigCostRegNet 6 | class ScaledDotProductAttention(nn.Module): 7 | ''' Scaled Dot-Product Attention ''' 8 | 9 | def __init__(self, temperature, attn_dropout=0.1): 10 | super().__init__() 11 | self.temperature = temperature 12 | # self.dropout = nn.Dropout(attn_dropout) 13 | 14 | def forward(self, q, k, v, mask=None): 15 | 16 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 17 | 18 | if mask is not None: 19 | attn = attn.masked_fill(mask == 0, -1e9) 20 | # attn = attn * mask 21 | 22 | attn = F.softmax(attn, dim=-1) 23 | # attn = self.dropout(F.softmax(attn, dim=-1)) 24 | output = torch.matmul(attn, v) 25 | 26 | return output, attn 27 | 28 | class MultiHeadAttention(nn.Module): 29 | ''' Multi-Head Attention module ''' 30 | 31 | def __init__(self, n_head, d_model, d_model_k, d_k, d_v, dropout=0.1): 32 | super().__init__() 33 | 34 | self.n_head = n_head 35 | self.d_k = d_k 36 | self.d_v = d_v 37 | 38 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 39 | self.w_ks = nn.Linear(d_model_k, n_head * d_k, bias=False) 40 | self.w_vs = nn.Linear(d_model_k, n_head * d_v, bias=False) 41 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 42 | 43 | self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5) 44 | 45 | # self.dropout = nn.Dropout(dropout) 46 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 47 | 48 | def forward(self, q, k, v, mask=None): 49 | 50 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 51 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 52 | 53 | residual = q 54 | 55 | # Pass through the pre-attention projection: b x lq x (n*dv) 56 | # Separate different heads: b x lq x n x dv 57 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 58 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 59 | v = 
self.w_vs(v).view(sz_b, len_v, n_head, d_v) 60 | 61 | # Transpose for attention dot product: b x n x lq x dv 62 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 63 | 64 | if mask is not None: 65 | mask = mask.unsqueeze(1) # For head axis broadcasting. 66 | 67 | q, attn = self.attention(q, k, v, mask=mask) 68 | 69 | # Transpose to move the head dimension back: b x lq x n x dv 70 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 71 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 72 | # q = self.dropout(self.fc(q)) 73 | q = self.fc(q) 74 | q += residual 75 | 76 | return q, attn 77 | 78 | class NeRF(nn.Module): 79 | def __init__(self, hid_n=64, feat_ch=16+3): 80 | """ 81 | """ 82 | super(NeRF, self).__init__() 83 | self.hid_n = hid_n 84 | self.agg = Agg(feat_ch) 85 | self.lr0 = nn.Sequential(nn.Linear(8+16, hid_n), 86 | nn.ReLU()) 87 | self.lrs = nn.ModuleList([ 88 | nn.Sequential(nn.Linear(hid_n, hid_n), nn.ReLU()) for i in range(0) 89 | ]) 90 | self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus()) 91 | self.color = nn.Sequential( 92 | nn.Linear(64+24+feat_ch+4, hid_n), 93 | nn.ReLU(), 94 | nn.Linear(hid_n, 1), 95 | nn.ReLU()) 96 | self.lr0.apply(weights_init) 97 | self.lrs.apply(weights_init) 98 | self.sigma.apply(weights_init) 99 | self.color.apply(weights_init) 100 | self.regnet = SigCostRegNet(hid_n) 101 | self.attn = MultiHeadAttention(8,hid_n,feat_ch+4,4,4) 102 | 103 | 104 | def forward(self, vox_feat, img_feat_rgb_dir, size): 105 | H,W = size 106 | B, N_points, N_views = img_feat_rgb_dir.shape[:-1] 107 | S = img_feat_rgb_dir.shape[2] 108 | img_feat = self.agg(img_feat_rgb_dir) 109 | vox_img_feat = torch.cat((vox_feat, img_feat), dim=-1) 110 | x = self.lr0(vox_img_feat) 111 | x = x.reshape(B,H,W,-1,x.shape[-1]) 112 | x = x.permute(0,4,3,1,2) 113 | x = self.regnet(x) 114 | x = x.permute(0,1,3,4,2).flatten(2).permute(0,2,1) 115 | q = x.squeeze(0).unsqueeze(1) 116 | k = img_feat_rgb_dir.squeeze(0) 117 | x,_ = self.attn(q,k,k) 118 | x = x.squeeze(1).unsqueeze(0) 119 | for i in range(len(self.lrs)): 120 | x = self.lrs[i](x) 121 | sigma = self.sigma(x) 122 | x = torch.cat((x, vox_img_feat), dim=-1) 123 | fea_transmit = x.clone() 124 | x = x.view(B, -1, 1, x.shape[-1]).repeat(1, 1, S, 1) 125 | x = torch.cat((x, img_feat_rgb_dir), dim=-1) 126 | color_weight = F.softmax(self.color(x), dim=-2) 127 | color = torch.sum((img_feat_rgb_dir[..., -7:-4] * color_weight), dim=-2) 128 | return torch.cat([color, sigma], dim=-1), fea_transmit 129 | 130 | class Agg(nn.Module): 131 | def __init__(self, feat_ch): 132 | """ 133 | """ 134 | super(Agg, self).__init__() 135 | self.feat_ch = feat_ch 136 | if cfg.gefu.viewdir_agg: 137 | self.view_fc = nn.Sequential( 138 | nn.Linear(4, feat_ch), 139 | nn.ReLU(), 140 | ) 141 | self.view_fc.apply(weights_init) 142 | self.global_fc = nn.Sequential( 143 | nn.Linear(feat_ch*3, 32), 144 | nn.ReLU(), 145 | ) 146 | 147 | self.agg_w_fc = nn.Sequential( 148 | nn.Linear(32, 1), 149 | nn.ReLU(), 150 | ) 151 | self.fc = nn.Sequential( 152 | nn.Linear(32, 16), 153 | nn.ReLU(), 154 | ) 155 | self.global_fc.apply(weights_init) 156 | self.agg_w_fc.apply(weights_init) 157 | self.fc.apply(weights_init) 158 | 159 | def forward(self, img_feat_rgb_dir): 160 | B, S = len(img_feat_rgb_dir), img_feat_rgb_dir.shape[-2] 161 | if cfg.gefu.viewdir_agg: 162 | view_feat = self.view_fc(img_feat_rgb_dir[..., -4:]) 163 | img_feat_rgb = img_feat_rgb_dir[..., :-4] + view_feat 164 | else: 165 | img_feat_rgb = 
img_feat_rgb_dir[..., :-4] 166 | 167 | var_feat = torch.var(img_feat_rgb, dim=-2).view(B, -1, 1, self.feat_ch).repeat(1, 1, S, 1) 168 | avg_feat = torch.mean(img_feat_rgb, dim=-2).view(B, -1, 1, self.feat_ch).repeat(1, 1, S, 1) 169 | 170 | feat = torch.cat([img_feat_rgb, var_feat, avg_feat], dim=-1) 171 | global_feat = self.global_fc(feat) 172 | agg_w = F.softmax(self.agg_w_fc(global_feat), dim=-2) 173 | im_feat = (global_feat * agg_w).sum(dim=-2) 174 | return self.fc(im_feat) 175 | 176 | class MVSNeRF(nn.Module): 177 | def __init__(self, hid_n=64, feat_ch=16+3): 178 | """ 179 | """ 180 | super(MVSNeRF, self).__init__() 181 | self.hid_n = hid_n 182 | self.lr0 = nn.Sequential(nn.Linear(8+feat_ch*3, hid_n), 183 | nn.ReLU()) 184 | self.lrs = nn.ModuleList([ 185 | nn.Sequential(nn.Linear(hid_n, hid_n), nn.ReLU()) for i in range(0) 186 | ]) 187 | self.sigma = nn.Sequential(nn.Linear(hid_n, 1), nn.Softplus()) 188 | self.color = nn.Sequential( 189 | nn.Linear(hid_n, hid_n), 190 | nn.ReLU(), 191 | nn.Linear(hid_n, 3)) 192 | self.lr0.apply(weights_init) 193 | self.lrs.apply(weights_init) 194 | self.sigma.apply(weights_init) 195 | self.color.apply(weights_init) 196 | 197 | def forward(self, vox_feat, img_feat_rgb_dir): 198 | B, N_points, N_views = img_feat_rgb_dir.shape[:-1] 199 | # img_feat = self.agg(img_feat_rgb_dir) 200 | img_feat = torch.cat([img_feat_rgb_dir[..., i, :-4] for i in range(N_views)] , dim=-1) 201 | S = img_feat_rgb_dir.shape[2] 202 | vox_img_feat = torch.cat((vox_feat, img_feat), dim=-1) 203 | x = self.lr0(vox_img_feat) 204 | for i in range(len(self.lrs)): 205 | x = self.lrs[i](x) 206 | sigma = self.sigma(x) 207 | # x = torch.cat((x, vox_img_feat), dim=-1) 208 | # x = x.view(B, -1, 1, x.shape[-1]).repeat(1, 1, S, 1) 209 | # x = torch.cat((x, img_feat_rgb_dir), dim=-1) 210 | color = torch.sigmoid(self.color(x)) 211 | return torch.cat([color, sigma], dim=-1) 212 | 213 | 214 | 215 | def weights_init(m): 216 | if isinstance(m, nn.Linear): 217 | nn.init.kaiming_normal_(m.weight.data) 218 | if m.bias is not None: 219 | nn.init.zeros_(m.bias.data) 220 | 221 | -------------------------------------------------------------------------------- /lib/networks/gefu/res_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ResidualConv(nn.Module): 5 | def __init__(self, input_dim, output_dim, stride, padding): 6 | super(ResidualConv, self).__init__() 7 | 8 | self.conv_block = nn.Sequential( 9 | nn.BatchNorm2d(input_dim), 10 | nn.ReLU(), 11 | nn.Conv2d( 12 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding 13 | ), 14 | nn.BatchNorm2d(output_dim), 15 | nn.ReLU(), 16 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), 17 | ) 18 | self.conv_skip = nn.Sequential( 19 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), 20 | nn.BatchNorm2d(output_dim), 21 | ) 22 | 23 | def forward(self, x): 24 | 25 | return self.conv_block(x) + self.conv_skip(x) 26 | 27 | class Upsample(nn.Module): 28 | def __init__(self, input_dim, output_dim, kernel, stride): 29 | super(Upsample, self).__init__() 30 | 31 | self.upsample = nn.ConvTranspose2d( 32 | input_dim, output_dim, kernel_size=kernel, stride=stride 33 | ) 34 | 35 | def forward(self, x): 36 | return self.upsample(x) 37 | 38 | 39 | 40 | class ResUnet(nn.Module): 41 | def __init__(self, channel=3, filters=[16, 32, 64, 128]): 42 | super(ResUnet, self).__init__() 43 | 44 | self.input_layer = nn.Sequential( 45 | 
nn.Conv2d(channel, filters[0], kernel_size=3, padding=1), 46 | nn.BatchNorm2d(filters[0]), 47 | nn.ReLU(), 48 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 49 | ) 50 | self.input_skip = nn.Sequential( 51 | nn.Conv2d(channel, filters[0], kernel_size=3, padding=1) 52 | ) 53 | 54 | self.residual_conv_1 = ResidualConv(filters[0], filters[1], 2, 1) 55 | self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1) 56 | 57 | self.bridge = ResidualConv(filters[2], filters[3], 2, 1) 58 | 59 | self.upsample_1 = Upsample(filters[3], filters[3], 2, 2) 60 | # self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1) 61 | 62 | # self.upsample_2 = Upsample(filters[2], filters[2], 2, 2) 63 | # self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1) 64 | 65 | # self.upsample_3 = Upsample(filters[1], filters[1], 2, 2) 66 | # self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], 1, 1) 67 | 68 | self.output_layer = nn.Sequential( 69 | nn.Conv2d(filters[2]+filters[3], 32, 1, 1) 70 | ) 71 | 72 | def forward(self, x): 73 | # Encode 74 | B, S, C, H, W = x.shape 75 | x = x.view(B*S, C, H, W) 76 | x1 = self.input_layer(x) + self.input_skip(x) 77 | x2 = self.residual_conv_1(x1) 78 | x3 = self.residual_conv_2(x2) 79 | # Bridge 80 | x4 = self.bridge(x3) 81 | # Decode 82 | x4 = self.upsample_1(x4) 83 | x5 = torch.cat([x4, x3], dim=1) 84 | 85 | # x6 = self.up_residual_conv1(x5) 86 | 87 | # x6 = self.upsample_2(x6) 88 | # x7 = torch.cat([x6, x2], dim=1) 89 | 90 | # x8 = self.up_residual_conv2(x7) 91 | 92 | # x8 = self.upsample_3(x8) 93 | # x9 = torch.cat([x8, x1], dim=1) 94 | 95 | # x10 = self.up_residual_conv3(x9) 96 | 97 | output = self.output_layer(x5) 98 | output = output.view(B, S, 32, H//4, W//4) 99 | return output 100 | -------------------------------------------------------------------------------- /lib/networks/make_network.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imp 3 | 4 | 5 | def make_network(cfg): 6 | module = cfg.network_module 7 | path = cfg.network_path 8 | network = imp.load_source(module, path).Network() 9 | return network 10 | -------------------------------------------------------------------------------- /lib/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainers import make_trainer 2 | from .optimizer import make_optimizer 3 | from .scheduler import make_lr_scheduler, set_lr_scheduler 4 | from .recorder import make_recorder 5 | 6 | -------------------------------------------------------------------------------- /lib/train/losses/gefu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from lib.config import cfg 5 | from lib.train.losses.vgg_perceptual_loss import VGGPerceptualLoss 6 | from lib.train.losses.ssim_loss import SSIM 7 | 8 | 9 | class NetworkWrapper(nn.Module): 10 | def __init__(self, net, train_loader): 11 | super(NetworkWrapper, self).__init__() 12 | self.device = torch.device('cuda:{}'.format(cfg.local_rank)) 13 | self.net = net 14 | self.color_crit = nn.MSELoss(reduction='mean') 15 | self.mse2psnr = lambda x : -10. 
* torch.log(x) / torch.log(torch.Tensor([10.])) 16 | self.perceptual_loss = VGGPerceptualLoss().to(self.device) 17 | 18 | def forward(self, batch): 19 | output = self.net(batch) 20 | 21 | scalar_stats = {} 22 | loss = 0 23 | for i in range(cfg.gefu.cas_config.num): 24 | color_loss = self.color_crit(batch[f'rgb_{i}'], output[f'rgb_level{i}']) 25 | scalar_stats.update({f'color_mse_{i}': color_loss}) 26 | loss += cfg.gefu.cas_config.loss_weight[i] * color_loss 27 | 28 | psnr = -10. * torch.log(color_loss) / torch.log(torch.Tensor([10.]).to(color_loss.device)) 29 | scalar_stats.update({f'psnr_{i}': psnr}) 30 | 31 | num_patchs = cfg.gefu.cas_config.num_patchs[i] 32 | if cfg.gefu.cas_config.train_img[i]: 33 | render_scale = cfg.gefu.cas_config.render_scale[i] 34 | B, S, C, H, W = batch['src_inps'].shape 35 | H, W = int(H * render_scale), int(W * render_scale) 36 | inp = output[f'rgb_level{i}'].reshape(B, H, W, 3).permute(0, 3, 1, 2) 37 | tar = batch[f'rgb_{i}'].reshape(B, H, W, 3).permute(0, 3, 1, 2) 38 | perceptual_loss = self.perceptual_loss(inp, tar) 39 | loss += 0.1 * perceptual_loss * cfg.gefu.cas_config.loss_weight[i] 40 | scalar_stats.update({f'perceptual_loss_{i}': perceptual_loss.detach()}) 41 | ssim = SSIM(window_size = 7) 42 | ssim_loss = 1-ssim(inp, tar) 43 | loss += 0.1 * ssim_loss * cfg.gefu.cas_config.loss_weight[i] 44 | scalar_stats.update({f'ssim_loss_{i}': ssim_loss.detach()}) 45 | 46 | elif num_patchs > 0: 47 | patch_size = cfg.gefu.cas_config.patch_size[i] 48 | num_rays = cfg.gefu.cas_config.num_rays[i] 49 | patch_rays = int(patch_size ** 2) 50 | inp = torch.empty((0, 3, patch_size, patch_size)).to(self.device) 51 | tar = torch.empty((0, 3, patch_size, patch_size)).to(self.device) 52 | for j in range(num_patchs): 53 | inp = torch.cat([inp, output[f'rgb_level{i}'][:, num_rays+j*patch_rays:num_rays+(j+1)*patch_rays, :].reshape(-1, patch_size, patch_size, 3).permute(0, 3, 1, 2)]) 54 | tar = torch.cat([tar, batch[f'rgb_{i}'][:, num_rays+j*patch_rays:num_rays+(j+1)*patch_rays, :].reshape(-1, patch_size, patch_size, 3).permute(0, 3, 1, 2)]) 55 | perceptual_loss = self.perceptual_loss(inp, tar) 56 | 57 | loss += 0.01 * perceptual_loss * cfg.gefu.cas_config.loss_weight[i] 58 | scalar_stats.update({f'perceptual_loss_{i}': perceptual_loss.detach()}) 59 | 60 | scalar_stats.update({'loss': loss}) 61 | image_stats = {} 62 | 63 | return output, loss, scalar_stats, image_stats 64 | 65 | -------------------------------------------------------------------------------- /lib/train/losses/nerf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lib.utils import net_utils 4 | from lib.config import cfg 5 | 6 | class NetworkWrapper(nn.Module): 7 | def __init__(self, net): 8 | super(NetworkWrapper, self).__init__() 9 | self.net = net 10 | self.color_crit = nn.MSELoss(reduction='mean') 11 | self.mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 12 | 13 | def forward(self, batch): 14 | output = self.net(batch) 15 | scalar_stats = {} 16 | loss = 0 17 | color_loss = self.color_crit(output['rgb_0'], batch['rgb']) 18 | scalar_stats.update({'color_mse_0': color_loss}) 19 | loss += color_loss 20 | 21 | psnr = -10. 
* torch.log(color_loss.detach()) / \ 22 | torch.log(torch.Tensor([10.]).to(color_loss.device)) 23 | scalar_stats.update({'psnr_0': psnr}) 24 | 25 | if len(cfg.task_arg.cascade_samples) > 1: 26 | color_loss = self.color_crit(output['rgb_1'], batch['rgb']) 27 | scalar_stats.update({'color_mse_1': color_loss}) 28 | loss += color_loss 29 | 30 | psnr = -10. * torch.log(color_loss.detach()) / \ 31 | torch.log(torch.Tensor([10.]).to(color_loss.device)) 32 | scalar_stats.update({'psnr_1': psnr}) 33 | 34 | scalar_stats.update({'loss': loss}) 35 | image_stats = {} 36 | 37 | return output, loss, scalar_stats, image_stats 38 | -------------------------------------------------------------------------------- /lib/train/losses/ssim_loss.py: -------------------------------------------------------------------------------- 1 | ##### the code is borrowed from https://github.com/Po-Hsun-Su/pytorch-ssim ##### 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, channel): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 21 | 22 | mu1_sq = mu1.pow(2) 23 | mu2_sq = mu2.pow(2) 24 | mu1_mu2 = mu1*mu2 25 | 26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 29 | 30 | C1 = 0.01**2 31 | C2 = 0.03**2 32 | 33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 34 | 35 | if size_average: 36 | return ssim_map.mean() 37 | else: 38 | return ssim_map.mean(1).mean(1).mean(1) 39 | 40 | class SSIM(torch.nn.Module): 41 | def __init__(self, window_size = 11, size_average = True): 42 | super(SSIM, self).__init__() 43 | self.window_size = window_size 44 | self.size_average = size_average 45 | self.channel = 1 46 | self.window = create_window(window_size, self.channel) 47 | 48 | def forward(self, img1, img2): 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def ssim(img1, img2, window_size = 11, size_average = True): 67 | (_, channel, _, _) = img1.size() 68 | window = create_window(window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | return _ssim(img1, img2, window, window_size, channel, 
size_average) 75 | -------------------------------------------------------------------------------- /lib/train/losses/vgg_perceptual_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | class VGGPerceptualLoss(torch.nn.Module): 5 | def __init__(self, resize=False): 6 | super(VGGPerceptualLoss, self).__init__() 7 | blocks = [] 8 | blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval()) 9 | blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval()) 10 | blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval()) 11 | blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval()) 12 | for bl in blocks: 13 | for p in bl.parameters(): 14 | p.requires_grad = False 15 | self.blocks = torch.nn.ModuleList(blocks) 16 | self.transform = torch.nn.functional.interpolate 17 | self.resize = resize 18 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 19 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 20 | 21 | def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]): 22 | if input.shape[1] != 3: 23 | input = input.repeat(1, 3, 1, 1) 24 | target = target.repeat(1, 3, 1, 1) 25 | input = (input-self.mean) / self.std 26 | target = (target-self.mean) / self.std 27 | if self.resize: 28 | input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False) 29 | target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False) 30 | loss = 0.0 31 | x = input 32 | y = target 33 | for i, block in enumerate(self.blocks): 34 | x = block(x) 35 | y = block(y) 36 | if i in feature_layers: 37 | loss += torch.nn.functional.l1_loss(x, y) 38 | if i in style_layers: 39 | act_x = x.reshape(x.shape[0], x.shape[1], -1) 40 | act_y = y.reshape(y.shape[0], y.shape[1], -1) 41 | gram_x = act_x @ act_x.permute(0, 2, 1) 42 | gram_y = act_y @ act_y.permute(0, 2, 1) 43 | loss += torch.nn.functional.l1_loss(gram_x, gram_y) 44 | return loss 45 | -------------------------------------------------------------------------------- /lib/train/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib.utils.optimizer.radam import RAdam 3 | 4 | 5 | _optimizer_factory = { 6 | 'adam': torch.optim.Adam, 7 | 'radam': RAdam, 8 | 'sgd': torch.optim.SGD 9 | } 10 | 11 | 12 | def make_optimizer(cfg, net): 13 | params = [] 14 | lr = cfg.train.lr 15 | weight_decay = cfg.train.weight_decay 16 | eps = cfg.train.eps 17 | 18 | for key, value in net.named_parameters(): 19 | if not value.requires_grad: 20 | continue 21 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay, "eps": eps}] 22 | 23 | if 'adam' in cfg.train.optim: 24 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, weight_decay=weight_decay, eps=eps) 25 | else: 26 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, momentum=0.9) 27 | 28 | return optimizer 29 | -------------------------------------------------------------------------------- /lib/train/recorder.py: -------------------------------------------------------------------------------- 1 | from collections import deque, defaultdict 2 | import torch 3 | from tensorboardX import SummaryWriter 4 | import os 5 | from lib.config.config import cfg 6 | 7 | from termcolor import colored 8 | 9 | 10 | class SmoothedValue(object): 11 | """Track a series of values and 
provide access to smoothed values over a 12 | window or the global series average. 13 | """ 14 | 15 | def __init__(self, window_size=20): 16 | self.deque = deque(maxlen=window_size) 17 | self.total = 0.0 18 | self.count = 0 19 | 20 | def update(self, value): 21 | self.deque.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def avg(self): 32 | d = torch.tensor(list(self.deque)) 33 | return d.mean().item() 34 | 35 | @property 36 | def global_avg(self): 37 | return self.total / self.count 38 | 39 | 40 | def process_volsdf(image_stats): 41 | for k, v in image_stats.items(): 42 | image_stats[k] = torch.clamp(v[0].permute(2, 0, 1), min=0., max=1.) 43 | return image_stats 44 | 45 | process_neus = process_volsdf 46 | 47 | class Recorder(object): 48 | def __init__(self, cfg): 49 | if cfg.local_rank > 0: 50 | return 51 | 52 | log_dir = cfg.record_dir 53 | if not cfg.resume: 54 | print(colored('remove contents of directory %s' % log_dir, 'red')) 55 | os.system('rm -r %s/*' % log_dir) 56 | self.writer = SummaryWriter(log_dir=log_dir) 57 | 58 | # scalars 59 | self.epoch = 0 60 | self.step = 0 61 | self.loss_stats = defaultdict(SmoothedValue) 62 | self.batch_time = SmoothedValue() 63 | self.data_time = SmoothedValue() 64 | 65 | # images 66 | self.image_stats = defaultdict(object) 67 | if 'process_' + cfg.task in globals(): 68 | self.processor = globals()['process_' + cfg.task] 69 | else: 70 | self.processor = None 71 | 72 | def update_loss_stats(self, loss_dict): 73 | if cfg.local_rank > 0: 74 | return 75 | for k, v in loss_dict.items(): 76 | self.loss_stats[k].update(v.detach().cpu()) 77 | 78 | def update_image_stats(self, image_stats): 79 | if cfg.local_rank > 0: 80 | return 81 | if self.processor is None: 82 | return 83 | image_stats = self.processor(image_stats) 84 | for k, v in image_stats.items(): 85 | self.image_stats[k] = v.detach().cpu() 86 | 87 | def record(self, prefix, step=-1, loss_stats=None, image_stats=None): 88 | if cfg.local_rank > 0: 89 | return 90 | 91 | pattern = prefix + '/{}' 92 | step = step if step >= 0 else self.step 93 | loss_stats = loss_stats if loss_stats else self.loss_stats 94 | 95 | for k, v in loss_stats.items(): 96 | if isinstance(v, SmoothedValue): 97 | self.writer.add_scalar(pattern.format(k), v.median, step) 98 | else: 99 | self.writer.add_scalar(pattern.format(k), v, step) 100 | 101 | if self.processor is None: 102 | return 103 | image_stats = self.processor(image_stats) if image_stats else self.image_stats 104 | for k, v in image_stats.items(): 105 | self.writer.add_image(pattern.format(k), v, step) 106 | 107 | def state_dict(self): 108 | if cfg.local_rank > 0: 109 | return 110 | scalar_dict = {} 111 | scalar_dict['step'] = self.step 112 | return scalar_dict 113 | 114 | def load_state_dict(self, scalar_dict): 115 | if cfg.local_rank > 0: 116 | return 117 | self.step = scalar_dict['step'] 118 | 119 | def __str__(self): 120 | if cfg.local_rank > 0: 121 | return 122 | loss_state = [] 123 | for k, v in self.loss_stats.items(): 124 | loss_state.append('{}: {:.4f}'.format(k, v.avg)) 125 | loss_state = ' '.join(loss_state) 126 | 127 | recording_state = ' '.join(['epoch: {}', 'step: {}', '{}', 'data: {:.4f}', 'batch: {:.4f}']) 128 | return recording_state.format(self.epoch, self.step, loss_state, self.data_time.avg, self.batch_time.avg) 129 | 130 | 131 | def make_recorder(cfg): 132 | return Recorder(cfg) 133 | 
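# Usage sketch (illustrative only; it mirrors how lib/train/trainers/trainer.py drives this class,
# with `cfg`, `loss_stats`, `image_stats`, `batch_time` and `data_time` supplied by the caller):
#
#   recorder = make_recorder(cfg)
#   recorder.step += 1                         # advance the global TensorBoard step
#   recorder.update_loss_stats(loss_stats)     # scalars are smoothed via SmoothedValue
#   recorder.batch_time.update(batch_time)
#   recorder.data_time.update(data_time)
#   recorder.update_image_stats(image_stats)   # only takes effect when a 'process_<task>' hook exists
#   recorder.record('train')                   # write the smoothed (median) scalars and images to TensorBoard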
-------------------------------------------------------------------------------- /lib/train/scheduler.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from lib.utils.optimizer.lr_scheduler import WarmupMultiStepLR, MultiStepLR, ExponentialLR 3 | 4 | 5 | def make_lr_scheduler(cfg, optimizer): 6 | cfg_scheduler = cfg.train.scheduler 7 | if cfg_scheduler.type == 'multi_step': 8 | scheduler = MultiStepLR(optimizer, 9 | milestones=cfg_scheduler.milestones, 10 | gamma=cfg_scheduler.gamma) 11 | elif cfg_scheduler.type == 'exponential': 12 | scheduler = ExponentialLR(optimizer, 13 | decay_epochs=cfg_scheduler.decay_epochs, 14 | gamma=cfg_scheduler.gamma) 15 | return scheduler 16 | 17 | 18 | def set_lr_scheduler(cfg, scheduler): 19 | cfg_scheduler = cfg.train.scheduler 20 | if cfg_scheduler.type == 'multi_step': 21 | scheduler.milestones = Counter(cfg_scheduler.milestones) 22 | elif cfg_scheduler.type == 'exponential': 23 | scheduler.decay_epochs = cfg_scheduler.decay_epochs 24 | scheduler.gamma = cfg_scheduler.gamma 25 | -------------------------------------------------------------------------------- /lib/train/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_trainer import make_trainer 2 | -------------------------------------------------------------------------------- /lib/train/trainers/make_trainer.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | import imp 3 | 4 | 5 | def _wrapper_factory(cfg, network, train_loader=None): 6 | module = cfg.loss_module 7 | path = cfg.loss_path 8 | network_wrapper = imp.load_source(module, path).NetworkWrapper(network, train_loader) 9 | return network_wrapper 10 | 11 | 12 | def make_trainer(cfg, network, train_loader=None): 13 | network = _wrapper_factory(cfg, network, train_loader) 14 | return Trainer(network) 15 | -------------------------------------------------------------------------------- /lib/train/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import torch 4 | import tqdm 5 | from torch.nn import DataParallel 6 | from torch.nn.parallel import DistributedDataParallel 7 | from lib.config import cfg 8 | from lib.utils.data_utils import to_cuda 9 | 10 | 11 | class Trainer(object): 12 | def __init__(self, network): 13 | device = torch.device('cuda:{}'.format(cfg.local_rank)) 14 | network = network.to(device) 15 | if cfg.distributed: 16 | network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(network) 17 | network = DistributedDataParallel( 18 | network, 19 | device_ids=[cfg.local_rank], 20 | output_device=cfg.local_rank, 21 | find_unused_parameters=True 22 | ) 23 | self.network = network 24 | self.local_rank = cfg.local_rank 25 | self.device = device 26 | self.global_step = 0 27 | 28 | def reduce_loss_stats(self, loss_stats): 29 | reduced_losses = {k: torch.mean(v) for k, v in loss_stats.items()} 30 | return reduced_losses 31 | 32 | def to_cuda(self, batch): 33 | for k in batch: 34 | if isinstance(batch[k], tuple) or isinstance(batch[k], list): 35 | #batch[k] = [b.cuda() for b in batch[k]] 36 | batch[k] = [b.to(self.device) for b in batch[k]] 37 | elif isinstance(batch[k], dict): 38 | batch[k] = {key: self.to_cuda(batch[k][key]) for key in batch[k]} 39 | else: 40 | # batch[k] = batch[k].cuda() 41 | batch[k] = batch[k].to(self.device) 42 | return 
batch 43 | 44 | def train(self, epoch, data_loader, optimizer, recorder): 45 | max_iter = len(data_loader) 46 | self.network.train() 47 | end = time.time() 48 | if self.global_step == 0: 49 | self.global_step = cfg.ep_iter * epoch 50 | for iteration, batch in enumerate(data_loader): 51 | data_time = time.time() - end 52 | iteration = iteration + 1 53 | 54 | batch = to_cuda(batch, self.device) 55 | batch['step'] = 0 56 | output, loss, loss_stats, image_stats = self.network(batch) 57 | # training stage: loss; optimizer; scheduler 58 | loss = loss.mean() 59 | optimizer.zero_grad() 60 | loss.backward() 61 | torch.nn.utils.clip_grad_value_(self.network.parameters(), 40) 62 | optimizer.step() 63 | 64 | if cfg.local_rank > 0: 65 | continue 66 | 67 | # data recording stage: loss_stats, time, image_stats 68 | recorder.step += 1 69 | 70 | loss_stats = self.reduce_loss_stats(loss_stats) 71 | recorder.update_loss_stats(loss_stats) 72 | 73 | batch_time = time.time() - end 74 | end = time.time() 75 | recorder.batch_time.update(batch_time) 76 | recorder.data_time.update(data_time) 77 | 78 | self.global_step += 1 79 | if iteration % cfg.log_interval == 0 or iteration == (max_iter - 1): 80 | # print training state 81 | eta_seconds = recorder.batch_time.global_avg * (max_iter - iteration) 82 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 83 | lr = optimizer.param_groups[0]['lr'] 84 | memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 85 | 86 | training_state = ' '.join(['eta: {}', '{}', 'lr: {:.6f}', 'max_mem: {:.0f}']) 87 | training_state = training_state.format(eta_string, str(recorder), lr, memory) 88 | print(training_state) 89 | 90 | # record loss_stats and image_dict 91 | recorder.update_image_stats(image_stats) 92 | recorder.record('train') 93 | 94 | def val(self, epoch, data_loader, evaluator=None, recorder=None): 95 | self.network.eval() 96 | torch.cuda.empty_cache() 97 | val_loss_stats = {} 98 | image_stats = {} 99 | data_size = len(data_loader) 100 | for batch in tqdm.tqdm(data_loader): 101 | batch = to_cuda(batch, self.device) 102 | batch['step'] = recorder.step 103 | with torch.no_grad(): 104 | output, loss, loss_stats, _ = self.network(batch) 105 | if evaluator is not None: 106 | image_stats_ = evaluator.evaluate(output, batch) 107 | if image_stats_ is not None: 108 | image_stats.update(image_stats_) 109 | 110 | loss_stats = self.reduce_loss_stats(loss_stats) 111 | for k, v in loss_stats.items(): 112 | val_loss_stats.setdefault(k, 0) 113 | val_loss_stats[k] += v 114 | 115 | loss_state = [] 116 | for k in val_loss_stats.keys(): 117 | val_loss_stats[k] /= data_size 118 | loss_state.append('{}: {:.4f}'.format(k, val_loss_stats[k])) 119 | print(loss_state) 120 | 121 | if evaluator is not None: 122 | result = evaluator.summarize() 123 | val_loss_stats.update(result) 124 | 125 | if recorder: 126 | recorder.record('val', epoch, val_loss_stats, image_stats) 127 | return result -------------------------------------------------------------------------------- /lib/utils/base_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import numpy as np 4 | import cv2 5 | import time 6 | from termcolor import colored 7 | import importlib 8 | import torch.distributed as dist 9 | import math 10 | 11 | class perf_timer: 12 | def __init__(self, msg="Elapsed time: {}s", logf=lambda x: print(colored(x, 'yellow')), sync_cuda=True, use_ms=False, disabled=False): 13 | self.logf = logf 14 | self.msg = msg 15 | 
self.sync_cuda = sync_cuda 16 | self.use_ms = use_ms 17 | self.disabled = disabled 18 | 19 | self.loggedtime = None 20 | 21 | def __enter__(self,): 22 | if self.sync_cuda: 23 | torch.cuda.synchronize() 24 | self.start = time.perf_counter() 25 | 26 | def __exit__(self, exc_type, exc_value, traceback): 27 | if self.sync_cuda: 28 | torch.cuda.synchronize() 29 | self.logtime(self.msg) 30 | 31 | def logtime(self, msg=None, logf=None): 32 | if self.disabled: 33 | return 34 | # SAME CLASS, DIFFERENT FUNCTIONALITY, is this good? 35 | # call the logger for timing code sections 36 | if self.sync_cuda: 37 | torch.cuda.synchronize() 38 | 39 | # always remember current time 40 | prev = self.loggedtime 41 | self.loggedtime = time.perf_counter() 42 | 43 | # print it if we've remembered previous time 44 | if prev is not None and msg: 45 | logf = logf or self.logf 46 | diff = self.loggedtime-prev 47 | diff *= 1000 if self.use_ms else 1 48 | logf(msg.format(diff)) 49 | 50 | return self.loggedtime 51 | 52 | def read_pickle(pkl_path): 53 | with open(pkl_path, 'rb') as f: 54 | return pickle.load(f) 55 | 56 | 57 | def save_pickle(data, pkl_path): 58 | os.system('mkdir -p {}'.format(os.path.dirname(pkl_path))) 59 | with open(pkl_path, 'wb') as f: 60 | pickle.dump(data, f) 61 | 62 | 63 | def project(xyz, K, RT): 64 | """ 65 | xyz: [N, 3] 66 | K: [3, 3] 67 | RT: [3, 4] 68 | """ 69 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T 70 | xyz = np.dot(xyz, K.T) 71 | xy = xyz[:, :2] / xyz[:, 2:] 72 | return xy 73 | 74 | def get_bbox_2d(bbox, K, RT): 75 | pts = np.array([[bbox[0, 0], bbox[0, 1], bbox[0, 2]], 76 | [bbox[0, 0], bbox[0, 1], bbox[1, 2]], 77 | [bbox[0, 0], bbox[1, 1], bbox[0, 2]], 78 | [bbox[0, 0], bbox[1, 1], bbox[1, 2]], 79 | [bbox[1, 0], bbox[0, 1], bbox[0, 2]], 80 | [bbox[1, 0], bbox[0, 1], bbox[1, 2]], 81 | [bbox[1, 0], bbox[1, 1], bbox[0, 2]], 82 | [bbox[1, 0], bbox[1, 1], bbox[1, 2]], 83 | ]) 84 | pts_2d = project(pts, K, RT) 85 | return [pts_2d[:, 0].min(), pts_2d[:, 1].min(), pts_2d[:, 0].max(), pts_2d[:, 1].max()] 86 | 87 | 88 | def get_bound_corners(bounds): 89 | min_x, min_y, min_z = bounds[0] 90 | max_x, max_y, max_z = bounds[1] 91 | corners_3d = np.array([ 92 | [min_x, min_y, min_z], 93 | [min_x, min_y, max_z], 94 | [min_x, max_y, min_z], 95 | [min_x, max_y, max_z], 96 | [max_x, min_y, min_z], 97 | [max_x, min_y, max_z], 98 | [max_x, max_y, min_z], 99 | [max_x, max_y, max_z], 100 | ]) 101 | return corners_3d 102 | 103 | def get_bound_2d_mask(bounds, K, pose, H, W): 104 | corners_3d = get_bound_corners(bounds) 105 | corners_2d = project(corners_3d, K, pose) 106 | corners_2d = np.round(corners_2d).astype(int) 107 | mask = np.zeros((H, W), dtype=np.uint8) 108 | cv2.fillPoly(mask, [corners_2d[[0, 1, 3, 2, 0]]], 1) 109 | cv2.fillPoly(mask, [corners_2d[[4, 5, 7, 6, 5]]], 1) 110 | cv2.fillPoly(mask, [corners_2d[[0, 1, 5, 4, 0]]], 1) 111 | cv2.fillPoly(mask, [corners_2d[[2, 3, 7, 6, 2]]], 1) 112 | cv2.fillPoly(mask, [corners_2d[[0, 2, 6, 4, 0]]], 1) 113 | cv2.fillPoly(mask, [corners_2d[[1, 3, 7, 5, 1]]], 1) 114 | return mask 115 | 116 | def load_object(module_name, module_args, **extra_args): 117 | module_path = '.'.join(module_name.split('.')[:-1]) 118 | module = importlib.import_module(module_path) 119 | name = module_name.split('.')[-1] 120 | obj = getattr(module, name)(**extra_args, **module_args) 121 | return obj 122 | 123 | 124 | 125 | def get_indices(length): 126 | num_replicas = dist.get_world_size() 127 | rank = dist.get_rank() 128 | num_samples = int(math.ceil(length * 1.0 / num_replicas)) 129 | 
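    # pad with wrapped-around indices so the list divides evenly across ranks, then take this rank's contiguous block of num_samples entries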
total_size = num_samples * num_replicas 130 | indices = np.arange(length).tolist() 131 | indices += indices[: (total_size - len(indices))] 132 | offset = num_samples * rank 133 | indices = indices[offset:offset+num_samples] 134 | return indices 135 | 136 | 137 | class DotDict(dict): 138 | """ 139 | a dictionary that supports dot notation 140 | as well as dictionary access notation 141 | usage: d = DotDict() or d = DotDict({'val1':'first'}) 142 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 143 | get attributes: d.val2 or d['val2'] 144 | """ 145 | __getattr__ = dict.__getitem__ 146 | __setattr__ = dict.__setitem__ 147 | __delattr__ = dict.__delitem__ 148 | 149 | def __init__(self, dct=None): 150 | if dct is not None: 151 | for key, value in dct.items(): 152 | if hasattr(value, 'keys'): 153 | value = DotDict(value) 154 | # self[key] = value 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /lib/utils/data_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | mean_rgb = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3).astype(np.float32) 3 | std_rgb = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3).astype(np.float32) 4 | -------------------------------------------------------------------------------- /lib/utils/img_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from matplotlib import cm 3 | import matplotlib.pyplot as plt 4 | import matplotlib.patches as patches 5 | import numpy as np 6 | import cv2 7 | from lib.utils import data_config 8 | 9 | 10 | 11 | def unnormalize_img(img, mean, std): 12 | """ 13 | img: [3, h, w] 14 | """ 15 | img = img.detach().cpu().clone() 16 | # img = img / 255. 
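    # invert the channel-wise normalization (x * std + mean), then min-max rescale to [0, 1] for visualization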
17 | img *= torch.tensor(std).view(3, 1, 1) 18 | img += torch.tensor(mean).view(3, 1, 1) 19 | min_v = torch.min(img) 20 | img = (img - min_v) / (torch.max(img) - min_v) 21 | return img 22 | 23 | 24 | def bgr_to_rgb(img): 25 | return img[:, :, [2, 1, 0]] 26 | 27 | 28 | def horizon_concate(inp0, inp1): 29 | h0, w0 = inp0.shape[:2] 30 | h1, w1 = inp1.shape[:2] 31 | if inp0.ndim == 3: 32 | inp = np.zeros((max(h0, h1), w0 + w1, 3), dtype=inp0.dtype) 33 | inp[:h0, :w0, :] = inp0 34 | inp[:h1, w0:(w0 + w1), :] = inp1 35 | else: 36 | inp = np.zeros((max(h0, h1), w0 + w1), dtype=inp0.dtype) 37 | inp[:h0, :w0] = inp0 38 | inp[:h1, w0:(w0 + w1)] = inp1 39 | return inp 40 | 41 | 42 | def vertical_concate(inp0, inp1): 43 | h0, w0 = inp0.shape[:2] 44 | h1, w1 = inp1.shape[:2] 45 | if inp0.ndim == 3: 46 | inp = np.zeros((h0 + h1, max(w0, w1), 3), dtype=inp0.dtype) 47 | inp[:h0, :w0, :] = inp0 48 | inp[h0:(h0 + h1), :w1, :] = inp1 49 | else: 50 | inp = np.zeros((h0 + h1, max(w0, w1)), dtype=inp0.dtype) 51 | inp[:h0, :w0] = inp0 52 | inp[h0:(h0 + h1), :w1] = inp1 53 | return inp 54 | 55 | 56 | def transparent_cmap(cmap): 57 | """Copy colormap and set alpha values""" 58 | mycmap = cmap 59 | mycmap._init() 60 | mycmap._lut[:,-1] = 0.3 61 | return mycmap 62 | 63 | cmap = transparent_cmap(plt.get_cmap('jet')) 64 | 65 | 66 | def set_grid(ax, h, w, interval=8): 67 | ax.set_xticks(np.arange(0, w, interval)) 68 | ax.set_yticks(np.arange(0, h, interval)) 69 | ax.grid() 70 | ax.set_yticklabels([]) 71 | ax.set_xticklabels([]) 72 | 73 | 74 | color_list = np.array( 75 | [ 76 | 0.000, 0.447, 0.741, 77 | 0.850, 0.325, 0.098, 78 | 0.929, 0.694, 0.125, 79 | 0.494, 0.184, 0.556, 80 | 0.466, 0.674, 0.188, 81 | 0.301, 0.745, 0.933, 82 | 0.635, 0.078, 0.184, 83 | 0.300, 0.300, 0.300, 84 | 0.600, 0.600, 0.600, 85 | 1.000, 0.000, 0.000, 86 | 1.000, 0.500, 0.000, 87 | 0.749, 0.749, 0.000, 88 | 0.000, 1.000, 0.000, 89 | 0.000, 0.000, 1.000, 90 | 0.667, 0.000, 1.000, 91 | 0.333, 0.333, 0.000, 92 | 0.333, 0.667, 0.000, 93 | 0.333, 1.000, 0.000, 94 | 0.667, 0.333, 0.000, 95 | 0.667, 0.667, 0.000, 96 | 0.667, 1.000, 0.000, 97 | 1.000, 0.333, 0.000, 98 | 1.000, 0.667, 0.000, 99 | 1.000, 1.000, 0.000, 100 | 0.000, 0.333, 0.500, 101 | 0.000, 0.667, 0.500, 102 | 0.000, 1.000, 0.500, 103 | 0.333, 0.000, 0.500, 104 | 0.333, 0.333, 0.500, 105 | 0.333, 0.667, 0.500, 106 | 0.333, 1.000, 0.500, 107 | 0.667, 0.000, 0.500, 108 | 0.667, 0.333, 0.500, 109 | 0.667, 0.667, 0.500, 110 | 0.667, 1.000, 0.500, 111 | 1.000, 0.000, 0.500, 112 | 1.000, 0.333, 0.500, 113 | 1.000, 0.667, 0.500, 114 | 1.000, 1.000, 0.500, 115 | 0.000, 0.333, 1.000, 116 | 0.000, 0.667, 1.000, 117 | 0.000, 1.000, 1.000, 118 | 0.333, 0.000, 1.000, 119 | 0.333, 0.333, 1.000, 120 | 0.333, 0.667, 1.000, 121 | 0.333, 1.000, 1.000, 122 | 0.667, 0.000, 1.000, 123 | 0.667, 0.333, 1.000, 124 | 0.667, 0.667, 1.000, 125 | 0.667, 1.000, 1.000, 126 | 1.000, 0.000, 1.000, 127 | 1.000, 0.333, 1.000, 128 | 1.000, 0.667, 1.000, 129 | 0.167, 0.000, 0.000, 130 | 0.333, 0.000, 0.000, 131 | 0.500, 0.000, 0.000, 132 | 0.667, 0.000, 0.000, 133 | 0.833, 0.000, 0.000, 134 | 1.000, 0.000, 0.000, 135 | 0.000, 0.167, 0.000, 136 | 0.000, 0.333, 0.000, 137 | 0.000, 0.500, 0.000, 138 | 0.000, 0.667, 0.000, 139 | 0.000, 0.833, 0.000, 140 | 0.000, 1.000, 0.000, 141 | 0.000, 0.000, 0.167, 142 | 0.000, 0.000, 0.333, 143 | 0.000, 0.000, 0.500, 144 | 0.000, 0.000, 0.667, 145 | 0.000, 0.000, 0.833, 146 | 0.000, 0.000, 1.000, 147 | 0.000, 0.000, 0.000, 148 | 0.143, 0.143, 0.143, 149 | 0.286, 0.286, 0.286, 150 | 
0.429, 0.429, 0.429, 151 | 0.571, 0.571, 0.571, 152 | 0.714, 0.714, 0.714, 153 | 0.857, 0.857, 0.857, 154 | 1.000, 1.000, 1.000, 155 | 0.50, 0.5, 0 156 | ] 157 | ).astype(np.float32) 158 | colors = color_list.reshape((-1, 3)) * 255 159 | colors = np.array(colors, dtype=np.uint8).reshape(len(colors), 1, 1, 3) 160 | 161 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): 162 | """ 163 | depth: (H, W) 164 | """ 165 | x = np.nan_to_num(depth) # change nan to 0 166 | if minmax is None: 167 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 168 | ma = np.max(x) 169 | else: 170 | mi,ma = minmax 171 | 172 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 173 | x = (255*x).astype(np.uint8) 174 | x_ = cv2.applyColorMap(x, cmap) 175 | return x_, [mi,ma] 176 | -------------------------------------------------------------------------------- /lib/utils/mask_utils.py: -------------------------------------------------------------------------------- 1 | def get_class_ids_from_labels(labels): 2 | ids = [] 3 | for l in labels: 4 | ids.append(label_id_mapping_ade20k[l]) 5 | return ids 6 | 7 | def get_label_id_mapping(use_human_mask=False): 8 | if use_human_mask: 9 | return label_id_mapping_human 10 | else: 11 | return label_id_mapping_ade20k 12 | 13 | id_label_mapping_human = { 14 | 0: 'non_person', 15 | 1: 'person' 16 | } 17 | 18 | label_id_mapping_human = { 19 | 'non_person': 0, 20 | 'person': 1 21 | } 22 | 23 | id_label_mapping_ade20k = {0: 'wall', 24 | 1: 'building', 25 | 2: 'sky', 26 | 3: 'floor', 27 | 4: 'tree', 28 | 5: 'ceiling', 29 | 6: 'road', 30 | 7: 'bed ', 31 | 8: 'windowpane', 32 | 9: 'grass', 33 | 10: 'cabinet', 34 | 11: 'sidewalk', 35 | 12: 'person', 36 | 13: 'earth', 37 | 14: 'door', 38 | 15: 'table', 39 | 16: 'mountain', 40 | 17: 'plant', 41 | 18: 'curtain', 42 | 19: 'chair', 43 | 20: 'car', 44 | 21: 'water', 45 | 22: 'painting', 46 | 23: 'sofa', 47 | 24: 'shelf', 48 | 25: 'house', 49 | 26: 'sea', 50 | 27: 'mirror', 51 | 28: 'rug', 52 | 29: 'field', 53 | 30: 'armchair', 54 | 31: 'seat', 55 | 32: 'fence', 56 | 33: 'desk', 57 | 34: 'rock', 58 | 35: 'wardrobe', 59 | 36: 'lamp', 60 | 37: 'bathtub', 61 | 38: 'railing', 62 | 39: 'cushion', 63 | 40: 'base', 64 | 41: 'box', 65 | 42: 'column', 66 | 43: 'signboard', 67 | 44: 'chest of drawers', 68 | 45: 'counter', 69 | 46: 'sand', 70 | 47: 'sink', 71 | 48: 'skyscraper', 72 | 49: 'fireplace', 73 | 50: 'refrigerator', 74 | 51: 'grandstand', 75 | 52: 'path', 76 | 53: 'stairs', 77 | 54: 'runway', 78 | 55: 'case', 79 | 56: 'pool table', 80 | 57: 'pillow', 81 | 58: 'screen door', 82 | 59: 'stairway', 83 | 60: 'river', 84 | 61: 'bridge', 85 | 62: 'bookcase', 86 | 63: 'blind', 87 | 64: 'coffee table', 88 | 65: 'toilet', 89 | 66: 'flower', 90 | 67: 'book', 91 | 68: 'hill', 92 | 69: 'bench', 93 | 70: 'countertop', 94 | 71: 'stove', 95 | 72: 'palm', 96 | 73: 'kitchen island', 97 | 74: 'computer', 98 | 75: 'swivel chair', 99 | 76: 'boat', 100 | 77: 'bar', 101 | 78: 'arcade machine', 102 | 79: 'hovel', 103 | 80: 'bus', 104 | 81: 'towel', 105 | 82: 'light', 106 | 83: 'truck', 107 | 84: 'tower', 108 | 85: 'chandelier', 109 | 86: 'awning', 110 | 87: 'streetlight', 111 | 88: 'booth', 112 | 89: 'television receiver', 113 | 90: 'airplane', 114 | 91: 'dirt track', 115 | 92: 'apparel', 116 | 93: 'pole', 117 | 94: 'land', 118 | 95: 'bannister', 119 | 96: 'escalator', 120 | 97: 'ottoman', 121 | 98: 'bottle', 122 | 99: 'buffet', 123 | 100: 'poster', 124 | 101: 'stage', 125 | 102: 'van', 126 | 103: 'ship', 127 | 104: 'fountain', 128 | 105: 
'conveyer belt', 129 | 106: 'canopy', 130 | 107: 'washer', 131 | 108: 'plaything', 132 | 109: 'swimming pool', 133 | 110: 'stool', 134 | 111: 'barrel', 135 | 112: 'basket', 136 | 113: 'waterfall', 137 | 114: 'tent', 138 | 115: 'bag', 139 | 116: 'minibike', 140 | 117: 'cradle', 141 | 118: 'oven', 142 | 119: 'ball', 143 | 120: 'food', 144 | 121: 'step', 145 | 122: 'tank', 146 | 123: 'trade name', 147 | 124: 'microwave', 148 | 125: 'pot', 149 | 126: 'animal', 150 | 127: 'bicycle', 151 | 128: 'lake', 152 | 129: 'dishwasher', 153 | 130: 'screen', 154 | 131: 'blanket', 155 | 132: 'sculpture', 156 | 133: 'hood', 157 | 134: 'sconce', 158 | 135: 'vase', 159 | 136: 'traffic light', 160 | 137: 'tray', 161 | 138: 'ashcan', 162 | 139: 'fan', 163 | 140: 'pier', 164 | 141: 'crt screen', 165 | 142: 'plate', 166 | 143: 'monitor', 167 | 144: 'bulletin board', 168 | 145: 'shower', 169 | 146: 'radiator', 170 | 147: 'glass', 171 | 148: 'clock', 172 | 149: 'flag'} 173 | 174 | label_id_mapping_ade20k = {'airplane': 90, 175 | 'animal': 126, 176 | 'apparel': 92, 177 | 'arcade machine': 78, 178 | 'armchair': 30, 179 | 'ashcan': 138, 180 | 'awning': 86, 181 | 'bag': 115, 182 | 'ball': 119, 183 | 'bannister': 95, 184 | 'bar': 77, 185 | 'barrel': 111, 186 | 'base': 40, 187 | 'basket': 112, 188 | 'bathtub': 37, 189 | 'bed ': 7, 190 | 'bench': 69, 191 | 'bicycle': 127, 192 | 'blanket': 131, 193 | 'blind': 63, 194 | 'boat': 76, 195 | 'book': 67, 196 | 'bookcase': 62, 197 | 'booth': 88, 198 | 'bottle': 98, 199 | 'box': 41, 200 | 'bridge': 61, 201 | 'buffet': 99, 202 | 'building': 1, 203 | 'bulletin board': 144, 204 | 'bus': 80, 205 | 'cabinet': 10, 206 | 'canopy': 106, 207 | 'car': 20, 208 | 'case': 55, 209 | 'ceiling': 5, 210 | 'chair': 19, 211 | 'chandelier': 85, 212 | 'chest of drawers': 44, 213 | 'clock': 148, 214 | 'coffee table': 64, 215 | 'column': 42, 216 | 'computer': 74, 217 | 'conveyer belt': 105, 218 | 'counter': 45, 219 | 'countertop': 70, 220 | 'cradle': 117, 221 | 'crt screen': 141, 222 | 'curtain': 18, 223 | 'cushion': 39, 224 | 'desk': 33, 225 | 'dirt track': 91, 226 | 'dishwasher': 129, 227 | 'door': 14, 228 | 'earth': 13, 229 | 'escalator': 96, 230 | 'fan': 139, 231 | 'fence': 32, 232 | 'field': 29, 233 | 'fireplace': 49, 234 | 'flag': 149, 235 | 'floor': 3, 236 | 'flower': 66, 237 | 'food': 120, 238 | 'fountain': 104, 239 | 'glass': 147, 240 | 'grandstand': 51, 241 | 'grass': 9, 242 | 'hill': 68, 243 | 'hood': 133, 244 | 'house': 25, 245 | 'hovel': 79, 246 | 'kitchen island': 73, 247 | 'lake': 128, 248 | 'lamp': 36, 249 | 'land': 94, 250 | 'light': 82, 251 | 'microwave': 124, 252 | 'minibike': 116, 253 | 'mirror': 27, 254 | 'monitor': 143, 255 | 'mountain': 16, 256 | 'ottoman': 97, 257 | 'oven': 118, 258 | 'painting': 22, 259 | 'palm': 72, 260 | 'path': 52, 261 | 'person': 12, 262 | 'pier': 140, 263 | 'pillow': 57, 264 | 'plant': 17, 265 | 'plate': 142, 266 | 'plaything': 108, 267 | 'pole': 93, 268 | 'pool table': 56, 269 | 'poster': 100, 270 | 'pot': 125, 271 | 'radiator': 146, 272 | 'railing': 38, 273 | 'refrigerator': 50, 274 | 'river': 60, 275 | 'road': 6, 276 | 'rock': 34, 277 | 'rug': 28, 278 | 'runway': 54, 279 | 'sand': 46, 280 | 'sconce': 134, 281 | 'screen': 130, 282 | 'screen door': 58, 283 | 'sculpture': 132, 284 | 'sea': 26, 285 | 'seat': 31, 286 | 'shelf': 24, 287 | 'ship': 103, 288 | 'shower': 145, 289 | 'sidewalk': 11, 290 | 'signboard': 43, 291 | 'sink': 47, 292 | 'sky': 2, 293 | 'skyscraper': 48, 294 | 'sofa': 23, 295 | 'stage': 101, 296 | 'stairs': 53, 297 | 'stairway': 59, 298 | 
'step': 121, 299 | 'stool': 110, 300 | 'stove': 71, 301 | 'streetlight': 87, 302 | 'swimming pool': 109, 303 | 'swivel chair': 75, 304 | 'table': 15, 305 | 'tank': 122, 306 | 'television receiver': 89, 307 | 'tent': 114, 308 | 'toilet': 65, 309 | 'towel': 81, 310 | 'tower': 84, 311 | 'trade name': 123, 312 | 'traffic light': 136, 313 | 'tray': 137, 314 | 'tree': 4, 315 | 'truck': 83, 316 | 'van': 102, 317 | 'vase': 135, 318 | 'wall': 0, 319 | 'wardrobe': 35, 320 | 'washer': 107, 321 | 'water': 21, 322 | 'waterfall': 113, 323 | 'windowpane': 8} 324 | -------------------------------------------------------------------------------- /lib/utils/net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch import nn 4 | import numpy as np 5 | import torch.nn.functional 6 | from collections import OrderedDict 7 | from termcolor import colored 8 | import sys 9 | import yaml 10 | from lib.config import cfg 11 | 12 | 13 | def gen_rays_bbox(rays, bounds): 14 | rays_o, rays_d = rays[..., :3], rays[..., 3:6] 15 | norm_d = torch.norm(rays_d, dim=-1, keepdim=True) 16 | viewdir = rays_d / norm_d 17 | viewdir[(viewdir < 1e-5) & (viewdir > -1e-10)] = 1e-5 18 | viewdir[(viewdir > -1e-5) & (viewdir < 1e-10)] = -1e-5 19 | 20 | tmin = (bounds[0:1] - rays_o[:1]) / viewdir 21 | tmax = (bounds[1:2] - rays_o[:1]) / viewdir 22 | t1 = torch.min(tmin, tmax) 23 | t2 = torch.max(tmin, tmax) 24 | 25 | near = torch.max(t1, dim=-1)[0] 26 | far = torch.min(t2, dim=-1)[0] 27 | mask_at_box = near < far 28 | return mask_at_box 29 | 30 | import time 31 | class perf_timer: 32 | def __init__(self, msg="Elapsed time: {}s", logf=lambda x: print(colored(x, 'yellow')), sync_cuda=True, use_ms=False, disabled=False): 33 | self.logf = logf 34 | self.msg = msg 35 | self.sync_cuda = sync_cuda 36 | self.use_ms = use_ms 37 | self.disabled = disabled 38 | 39 | self.loggedtime = None 40 | 41 | def __enter__(self,): 42 | if self.sync_cuda: 43 | torch.cuda.synchronize() 44 | self.start = time.perf_counter() 45 | 46 | def __exit__(self, exc_type, exc_value, traceback): 47 | if self.sync_cuda: 48 | torch.cuda.synchronize() 49 | self.logtime(self.msg) 50 | 51 | def logtime(self, msg=None, logf=None): 52 | if self.disabled: 53 | return 54 | # SAME CLASS, DIFFERENT FUNCTIONALITY, is this good? 55 | # call the logger for timing code sections 56 | if self.sync_cuda: 57 | torch.cuda.synchronize() 58 | 59 | # always remember current time 60 | prev = self.loggedtime 61 | self.loggedtime = time.perf_counter() 62 | 63 | # print it if we've remembered previous time 64 | if prev is not None and msg: 65 | logf = logf or self.logf 66 | diff = self.loggedtime-prev 67 | diff *= 1000 if self.use_ms else 1 68 | logf(msg.format(diff)) 69 | 70 | return self.loggedtime 71 | 72 | def sigmoid(x): 73 | y = torch.clamp(x.sigmoid(), min=1e-4, max=1 - 1e-4) 74 | return y 75 | 76 | 77 | def _neg_loss(pred, gt): 78 | ''' Modified focal loss. Exactly the same as CornerNet. 
79 | Runs faster and costs a little bit more memory 80 | Arguments: 81 | pred (batch x c x h x w) 82 | gt_regr (batch x c x h x w) 83 | ''' 84 | pos_inds = gt.eq(1).float() 85 | neg_inds = gt.lt(1).float() 86 | 87 | neg_weights = torch.pow(1 - gt, 4) 88 | 89 | loss = 0 90 | 91 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds 92 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 93 | 2) * neg_weights * neg_inds 94 | 95 | num_pos = pos_inds.float().sum() 96 | pos_loss = pos_loss.sum() 97 | neg_loss = neg_loss.sum() 98 | 99 | if num_pos == 0: 100 | loss = loss - neg_loss 101 | else: 102 | loss = loss - (pos_loss + neg_loss) / num_pos 103 | return loss 104 | 105 | 106 | class FocalLoss(nn.Module): 107 | '''nn.Module warpper for focal loss''' 108 | def __init__(self): 109 | super(FocalLoss, self).__init__() 110 | self.neg_loss = _neg_loss 111 | 112 | def forward(self, out, target): 113 | return self.neg_loss(out, target) 114 | 115 | 116 | def smooth_l1_loss(vertex_pred, 117 | vertex_targets, 118 | vertex_weights, 119 | sigma=1.0, 120 | normalize=True, 121 | reduce=True): 122 | """ 123 | :param vertex_pred: [b, vn*2, h, w] 124 | :param vertex_targets: [b, vn*2, h, w] 125 | :param vertex_weights: [b, 1, h, w] 126 | :param sigma: 127 | :param normalize: 128 | :param reduce: 129 | :return: 130 | """ 131 | b, ver_dim, _, _ = vertex_pred.shape 132 | sigma_2 = sigma**2 133 | vertex_diff = vertex_pred - vertex_targets 134 | diff = vertex_weights * vertex_diff 135 | abs_diff = torch.abs(diff) 136 | smoothL1_sign = (abs_diff < 1. / sigma_2).detach().float() 137 | in_loss = torch.pow(diff, 2) * (sigma_2 / 2.) * smoothL1_sign \ 138 | + (abs_diff - (0.5 / sigma_2)) * (1. - smoothL1_sign) 139 | 140 | if normalize: 141 | in_loss = torch.sum(in_loss.view(b, -1), 1) / ( 142 | ver_dim * torch.sum(vertex_weights.view(b, -1), 1) + 1e-3) 143 | 144 | if reduce: 145 | in_loss = torch.mean(in_loss) 146 | 147 | return in_loss 148 | 149 | 150 | class SmoothL1Loss(nn.Module): 151 | def __init__(self): 152 | super(SmoothL1Loss, self).__init__() 153 | self.smooth_l1_loss = smooth_l1_loss 154 | 155 | def forward(self, 156 | preds, 157 | targets, 158 | weights, 159 | sigma=1.0, 160 | normalize=True, 161 | reduce=True): 162 | return self.smooth_l1_loss(preds, targets, weights, sigma, normalize, 163 | reduce) 164 | 165 | 166 | class AELoss(nn.Module): 167 | def __init__(self): 168 | super(AELoss, self).__init__() 169 | 170 | def forward(self, ae, ind, ind_mask): 171 | """ 172 | ae: [b, 1, h, w] 173 | ind: [b, max_objs, max_parts] 174 | ind_mask: [b, max_objs, max_parts] 175 | obj_mask: [b, max_objs] 176 | """ 177 | # first index 178 | b, _, h, w = ae.shape 179 | b, max_objs, max_parts = ind.shape 180 | obj_mask = torch.sum(ind_mask, dim=2) != 0 181 | 182 | ae = ae.view(b, h * w, 1) 183 | seed_ind = ind.view(b, max_objs * max_parts, 1) 184 | tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts) 185 | 186 | # compute the mean 187 | tag_mean = tag * ind_mask 188 | tag_mean = tag_mean.sum(2) / (ind_mask.sum(2) + 1e-4) 189 | 190 | # pull ae of the same object to their mean 191 | pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask 192 | obj_num = obj_mask.sum(dim=1).float() 193 | pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum() 194 | pull /= b 195 | 196 | # push away the mean of different objects 197 | push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2)) 198 | push_dist = 1 - push_dist 199 | push_dist = nn.functional.relu(push_dist, inplace=True) 200 | obj_mask = 
(obj_mask.unsqueeze(1) + obj_mask.unsqueeze(2)) == 2 201 | push_dist = push_dist * obj_mask.float() 202 | push = ((push_dist.sum(dim=(1, 2)) - obj_num) / 203 | (obj_num * (obj_num - 1) + 1e-4)).sum() 204 | push /= b 205 | return pull, push 206 | 207 | 208 | class PolyMatchingLoss(nn.Module): 209 | def __init__(self, pnum): 210 | super(PolyMatchingLoss, self).__init__() 211 | 212 | self.pnum = pnum 213 | batch_size = 1 214 | pidxall = np.zeros(shape=(batch_size, pnum, pnum), dtype=np.int32) 215 | for b in range(batch_size): 216 | for i in range(pnum): 217 | pidx = (np.arange(pnum) + i) % pnum 218 | pidxall[b, i] = pidx 219 | 220 | device = torch.device('cuda') 221 | pidxall = torch.from_numpy( 222 | np.reshape(pidxall, newshape=(batch_size, -1))).to(device) 223 | 224 | self.feature_id = pidxall.unsqueeze_(2).long().expand( 225 | pidxall.size(0), pidxall.size(1), 2).detach() 226 | 227 | def forward(self, pred, gt, loss_type="L2"): 228 | pnum = self.pnum 229 | batch_size = pred.size()[0] 230 | feature_id = self.feature_id.expand(batch_size, 231 | self.feature_id.size(1), 2) 232 | device = torch.device('cuda') 233 | 234 | gt_expand = torch.gather(gt, 1, 235 | feature_id).view(batch_size, pnum, pnum, 2) 236 | 237 | pred_expand = pred.unsqueeze(1) 238 | 239 | dis = pred_expand - gt_expand 240 | 241 | if loss_type == "L2": 242 | dis = (dis**2).sum(3).sqrt().sum(2) 243 | elif loss_type == "L1": 244 | dis = torch.abs(dis).sum(3).sum(2) 245 | 246 | min_dis, min_id = torch.min(dis, dim=1, keepdim=True) 247 | # print(min_id) 248 | 249 | # min_id = torch.from_numpy(min_id.data.cpu().numpy()).to(device) 250 | # min_gt_id_to_gather = min_id.unsqueeze_(2).unsqueeze_(3).long().\ 251 | # expand(min_id.size(0), min_id.size(1), gt_expand.size(2), gt_expand.size(3)) 252 | # gt_right_order = torch.gather(gt_expand, 1, min_gt_id_to_gather).view(batch_size, pnum, 2) 253 | 254 | return torch.mean(min_dis) 255 | 256 | 257 | class AttentionLoss(nn.Module): 258 | def __init__(self, beta=4, gamma=0.5): 259 | super(AttentionLoss, self).__init__() 260 | 261 | self.beta = beta 262 | self.gamma = gamma 263 | 264 | def forward(self, pred, gt): 265 | num_pos = torch.sum(gt) 266 | num_neg = torch.sum(1 - gt) 267 | alpha = num_neg / (num_pos + num_neg) 268 | edge_beta = torch.pow(self.beta, torch.pow(1 - pred, self.gamma)) 269 | bg_beta = torch.pow(self.beta, torch.pow(pred, self.gamma)) 270 | 271 | loss = 0 272 | loss = loss - alpha * edge_beta * torch.log(pred) * gt 273 | loss = loss - (1 - alpha) * bg_beta * torch.log(1 - pred) * (1 - gt) 274 | return torch.mean(loss) 275 | 276 | 277 | def _gather_feat(feat, ind, mask=None): 278 | dim = feat.size(2) 279 | ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) 280 | feat = feat.gather(1, ind) 281 | if mask is not None: 282 | mask = mask.unsqueeze(2).expand_as(feat) 283 | feat = feat[mask] 284 | feat = feat.view(-1, dim) 285 | return feat 286 | 287 | 288 | def _tranpose_and_gather_feat(feat, ind): 289 | feat = feat.permute(0, 2, 3, 1).contiguous() 290 | feat = feat.view(feat.size(0), -1, feat.size(3)) 291 | feat = _gather_feat(feat, ind) 292 | return feat 293 | 294 | 295 | class Ind2dRegL1Loss(nn.Module): 296 | def __init__(self, type='l1'): 297 | super(Ind2dRegL1Loss, self).__init__() 298 | if type == 'l1': 299 | self.loss = torch.nn.functional.l1_loss 300 | elif type == 'smooth_l1': 301 | self.loss = torch.nn.functional.smooth_l1_loss 302 | 303 | def forward(self, output, target, ind, ind_mask): 304 | """ind: [b, max_objs, max_parts]""" 305 | b, max_objs, max_parts = 
ind.shape 306 | ind = ind.view(b, max_objs * max_parts) 307 | pred = _tranpose_and_gather_feat(output, 308 | ind).view(b, max_objs, max_parts, 309 | output.size(1)) 310 | mask = ind_mask.unsqueeze(3).expand_as(pred) 311 | loss = self.loss(pred * mask, target * mask, reduction='sum') 312 | loss = loss / (mask.sum() + 1e-4) 313 | return loss 314 | 315 | 316 | class IndL1Loss1d(nn.Module): 317 | def __init__(self, type='l1'): 318 | super(IndL1Loss1d, self).__init__() 319 | if type == 'l1': 320 | self.loss = torch.nn.functional.l1_loss 321 | elif type == 'smooth_l1': 322 | self.loss = torch.nn.functional.smooth_l1_loss 323 | 324 | def forward(self, output, target, ind, weight): 325 | """ind: [b, n]""" 326 | output = _tranpose_and_gather_feat(output, ind) 327 | weight = weight.unsqueeze(2) 328 | loss = self.loss(output * weight, target * weight, reduction='sum') 329 | loss = loss / (weight.sum() * output.size(2) + 1e-4) 330 | return loss 331 | 332 | 333 | class GeoCrossEntropyLoss(nn.Module): 334 | def __init__(self): 335 | super(GeoCrossEntropyLoss, self).__init__() 336 | 337 | def forward(self, output, target, poly): 338 | output = torch.nn.functional.softmax(output, dim=1) 339 | output = torch.log(torch.clamp(output, min=1e-4)) 340 | poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2) 341 | target = target[..., None, None].expand(poly.size(0), poly.size(1), 1, 342 | poly.size(3)) 343 | target_poly = torch.gather(poly, 2, target) 344 | sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True) 345 | kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3)) 346 | loss = -(output * kernel.transpose(2, 1)).sum(1).mean() 347 | return loss 348 | 349 | 350 | def load_model(net, 351 | optim, 352 | scheduler, 353 | recorder, 354 | model_dir, 355 | resume=True, 356 | epoch=-1): 357 | if not resume: 358 | os.system('rm -rf {}'.format(model_dir)) 359 | 360 | if not os.path.exists(model_dir): 361 | return 0 362 | 363 | pths = [ 364 | int(pth.split('.')[0]) for pth in os.listdir(model_dir) 365 | if pth != 'latest.pth' and pth != 'psnr_best.pth' and pth != 'ssim_best.pth' and pth != 'lpips_best.pth' 366 | ] 367 | if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir): 368 | return 0 369 | if epoch == -1: 370 | if 'latest.pth' in os.listdir(model_dir): 371 | pth = 'latest' 372 | else: 373 | pth = max(pths) 374 | else: 375 | pth = epoch 376 | print('load model: {}'.format(os.path.join(model_dir, 377 | '{}.pth'.format(pth)))) 378 | pretrained_model = torch.load( 379 | os.path.join(model_dir, '{}.pth'.format(pth)), 'cpu') 380 | net.load_state_dict(pretrained_model['net']) 381 | if 'optim' in pretrained_model: 382 | # optim.load_state_dict(pretrained_model['optim']) ###ft 383 | scheduler.load_state_dict(pretrained_model['scheduler']) 384 | recorder.load_state_dict(pretrained_model['recorder']) 385 | return pretrained_model['epoch'] + 1 386 | else: 387 | return 0 388 | 389 | 390 | def save_model(net, optim, scheduler, recorder, model_dir, epoch, custom=None, last=False): 391 | os.system('mkdir -p {}'.format(model_dir)) 392 | model = { 393 | 'net': net.state_dict(), 394 | 'optim': optim.state_dict(), 395 | 'scheduler': scheduler.state_dict(), 396 | 'recorder': recorder.state_dict(), 397 | 'epoch': epoch 398 | } 399 | if last: 400 | torch.save(model, os.path.join(model_dir, 'latest.pth')) 401 | else: 402 | torch.save(model, os.path.join(model_dir, '{}.pth'.format(epoch))) 403 | if custom is not None: 404 | torch.save(model, os.path.join(model_dir, f'{custom}.pth')) 405 | 406 | # 
remove previous pretrained model if the number of models is too big 407 | pths = [ 408 | int(pth.split('.')[0]) for pth in os.listdir(model_dir) 409 | if pth != 'latest.pth' and pth != 'psnr_best.pth' and pth != 'ssim_best.pth' and pth != 'lpips_best.pth' 410 | ] 411 | # if len(pths) <= 5: 412 | # return 413 | # os.system('rm {}'.format( 414 | # os.path.join(model_dir, '{}.pth'.format(min(pths))))) 415 | 416 | 417 | def load_network(net, model_dir, resume=True, epoch=-1, custom=None, strict=True): 418 | if not resume: 419 | return 0 420 | if not os.path.exists(model_dir): 421 | print(colored('pretrained model does not exist', 'red')) 422 | return 0 423 | 424 | if os.path.isdir(model_dir): 425 | pths = [ 426 | int(pth.split('.')[0]) for pth in os.listdir(model_dir) 427 | if pth != 'latest.pth' and pth != 'psnr_best.pth' and pth != 'ssim_best.pth' and pth != 'lpips_best.pth' 428 | ] 429 | if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir): 430 | return 0 431 | if epoch == -1: 432 | if custom is not None and custom+'.pth' in os.listdir(model_dir): 433 | pth = custom 434 | elif 'latest.pth' in os.listdir(model_dir): 435 | pth = 'latest' 436 | else: 437 | pth = max(pths) 438 | else: 439 | pth = epoch 440 | model_path = os.path.join(model_dir, '{}.pth'.format(pth)) 441 | else: 442 | model_path = model_dir 443 | print('load model: {}'.format(model_path)) 444 | pretrained_model = torch.load(model_path) 445 | net.load_state_dict(pretrained_model['net'], strict=strict) 446 | if 'epoch' in pretrained_model: 447 | return pretrained_model['epoch'] + 1 448 | else: 449 | return 0 450 | 451 | 452 | def remove_net_prefix(net, prefix): 453 | net_ = OrderedDict() 454 | for k in net.keys(): 455 | if k.startswith(prefix): 456 | net_[k[len(prefix):]] = net[k] 457 | else: 458 | net_[k] = net[k] 459 | return net_ 460 | 461 | 462 | def add_net_prefix(net, prefix): 463 | net_ = OrderedDict() 464 | for k in net.keys(): 465 | net_[prefix + k] = net[k] 466 | return net_ 467 | 468 | 469 | def replace_net_prefix(net, orig_prefix, prefix): 470 | net_ = OrderedDict() 471 | for k in net.keys(): 472 | if k.startswith(orig_prefix): 473 | net_[prefix + k[len(orig_prefix):]] = net[k] 474 | else: 475 | net_[k] = net[k] 476 | return net_ 477 | 478 | 479 | def remove_net_layer(net, layers): 480 | keys = list(net.keys()) 481 | for k in keys: 482 | for layer in layers: 483 | if k.startswith(layer): 484 | del net[k] 485 | return net 486 | 487 | def save_trained_config(cfg): 488 | if not cfg.resume: 489 | os.system('rm -rf ' + cfg.trained_config_dir+'/*') 490 | os.system('mkdir -p ' + cfg.trained_config_dir) 491 | train_cmd = ' '.join(sys.argv) 492 | train_cmd_path = os.path.join(cfg.trained_config_dir, 'train_cmd.txt') 493 | train_config_path = os.path.join(cfg.trained_config_dir, 'train_config.yaml') 494 | open(train_cmd_path, 'w').write(train_cmd) 495 | yaml.dump(cfg, open(train_config_path, 'w')) 496 | 497 | def load_pretrain(net, model_dir): 498 | model_dir = os.path.join(cfg.workspace, 'trained_model', cfg.task, model_dir) 499 | if not os.path.exists(model_dir): 500 | return 1 501 | 502 | pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir) if pth != 'latest.pth'] 503 | if len(pths) == 0 and 'latest.pth' not in os.listdir(model_dir): 504 | return 1 505 | 506 | if 'latest.pth' in os.listdir(model_dir): 507 | pth = 'latest' 508 | else: 509 | pth = max(pths) 510 | print('Load pretrain model: {}'.format(os.path.join(model_dir, '{}.pth'.format(pth)))) 511 | pretrained_model = 
torch.load(os.path.join(model_dir, '{}.pth'.format(pth)), 'cpu') 512 | net.load_state_dict(pretrained_model['net']) 513 | return 0 514 | 515 | def save_pretrain(net, task, model_dir): 516 | model_dir = os.path.join('data/trained_model', task, model_dir) 517 | os.system('mkdir -p ' + model_dir) 518 | model = {'net': net.state_dict()} 519 | torch.save(model, os.path.join(model_dir, 'latest.pth')) 520 | 521 | 522 | -------------------------------------------------------------------------------- /lib/utils/optimizer/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | from collections import Counter 3 | 4 | import torch 5 | 6 | 7 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | milestones, 12 | gamma=0.1, 13 | warmup_factor=1.0 / 3, 14 | warmup_iters=5, 15 | warmup_method="linear", 16 | last_epoch=-1, 17 | ): 18 | if not list(milestones) == sorted(milestones): 19 | raise ValueError( 20 | "Milestones should be a list of" " increasing integers. Got {}", 21 | milestones, 22 | ) 23 | 24 | if warmup_method not in ("constant", "linear"): 25 | raise ValueError( 26 | "Only 'constant' or 'linear' warmup_method accepted" 27 | "got {}".format(warmup_method) 28 | ) 29 | self.milestones = milestones 30 | self.gamma = gamma 31 | self.warmup_factor = warmup_factor 32 | self.warmup_iters = warmup_iters 33 | self.warmup_method = warmup_method 34 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 35 | 36 | def get_lr(self): 37 | warmup_factor = 1 38 | if self.last_epoch < self.warmup_iters: 39 | if self.warmup_method == "constant": 40 | warmup_factor = self.warmup_factor 41 | elif self.warmup_method == "linear": 42 | alpha = float(self.last_epoch) / self.warmup_iters 43 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 44 | return [ 45 | base_lr 46 | * warmup_factor 47 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 48 | for base_lr in self.base_lrs 49 | ] 50 | 51 | 52 | class MultiStepLR(torch.optim.lr_scheduler._LRScheduler): 53 | 54 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): 55 | self.milestones = Counter(milestones) 56 | self.gamma = gamma 57 | super(MultiStepLR, self).__init__(optimizer, last_epoch) 58 | 59 | def get_lr(self): 60 | if self.last_epoch not in self.milestones: 61 | return [group['lr'] for group in self.optimizer.param_groups] 62 | return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] 63 | for group in self.optimizer.param_groups] 64 | 65 | 66 | class ExponentialLR(torch.optim.lr_scheduler._LRScheduler): 67 | 68 | def __init__(self, optimizer, decay_epochs, gamma=0.1, last_epoch=-1): 69 | self.decay_epochs = decay_epochs 70 | self.gamma = gamma 71 | super(ExponentialLR, self).__init__(optimizer, last_epoch) 72 | 73 | def get_lr(self): 74 | return [base_lr * self.gamma ** (self.last_epoch / self.decay_epochs) 75 | for base_lr in self.base_lrs] 76 | -------------------------------------------------------------------------------- /lib/utils/optimizer/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | 6 | class RAdam(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 9 | if not 0.0 <= lr: 10 | raise ValueError("Invalid learning rate: {}".format(lr)) 11 | 
if not 0.0 <= eps: 12 | raise ValueError("Invalid epsilon value: {}".format(eps)) 13 | if not 0.0 <= betas[0] < 1.0: 14 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 15 | if not 0.0 <= betas[1] < 1.0: 16 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 17 | 18 | self.degenerated_to_sgd = degenerated_to_sgd 19 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 20 | for param in params: 21 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 22 | param['buffer'] = [[None, None, None] for _ in range(10)] 23 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 24 | super(RAdam, self).__init__(params, defaults) 25 | 26 | def __setstate__(self, state): 27 | super(RAdam, self).__setstate__(state) 28 | 29 | def step(self, closure=None): 30 | 31 | loss = None 32 | if closure is not None: 33 | loss = closure() 34 | 35 | for group in self.param_groups: 36 | 37 | for p in group['params']: 38 | if p.grad is None: 39 | continue 40 | grad = p.grad.data.float() 41 | if grad.is_sparse: 42 | raise RuntimeError('RAdam does not support sparse gradients') 43 | 44 | p_data_fp32 = p.data.float() 45 | 46 | state = self.state[p] 47 | 48 | if len(state) == 0: 49 | state['step'] = 0 50 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 51 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 52 | else: 53 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 54 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 55 | 56 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 57 | beta1, beta2 = group['betas'] 58 | 59 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 60 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 61 | 62 | state['step'] += 1 63 | buffered = group['buffer'][int(state['step'] % 10)] 64 | if state['step'] == buffered[0]: 65 | N_sma, step_size = buffered[1], buffered[2] 66 | else: 67 | buffered[0] = state['step'] 68 | beta2_t = beta2 ** state['step'] 69 | N_sma_max = 2 / (1 - beta2) - 1 70 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 71 | buffered[1] = N_sma 72 | 73 | # more conservative since it's an approximated value 74 | if N_sma >= 5: 75 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 76 | elif self.degenerated_to_sgd: 77 | step_size = 1.0 / (1 - beta1 ** state['step']) 78 | else: 79 | step_size = -1 80 | buffered[2] = step_size 81 | 82 | # more conservative since it's an approximated value 83 | if N_sma >= 5: 84 | if group['weight_decay'] != 0: 85 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 86 | denom = exp_avg_sq.sqrt().add_(group['eps']) 87 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 88 | p.data.copy_(p_data_fp32) 89 | elif step_size > 0: 90 | if group['weight_decay'] != 0: 91 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 92 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 93 | p.data.copy_(p_data_fp32) 94 | 95 | return loss 96 | 97 | class PlainRAdam(Optimizer): 98 | 99 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 100 | if not 0.0 <= lr: 101 | raise ValueError("Invalid learning rate: {}".format(lr)) 102 | if not 0.0 <= eps: 103 | raise ValueError("Invalid epsilon value: {}".format(eps)) 104 | if 
not 0.0 <= betas[0] < 1.0: 105 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 106 | if not 0.0 <= betas[1] < 1.0: 107 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 108 | 109 | self.degenerated_to_sgd = degenerated_to_sgd 110 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 111 | 112 | super(PlainRAdam, self).__init__(params, defaults) 113 | 114 | def __setstate__(self, state): 115 | super(PlainRAdam, self).__setstate__(state) 116 | 117 | def step(self, closure=None): 118 | 119 | loss = None 120 | if closure is not None: 121 | loss = closure() 122 | 123 | for group in self.param_groups: 124 | 125 | for p in group['params']: 126 | if p.grad is None: 127 | continue 128 | grad = p.grad.data.float() 129 | if grad.is_sparse: 130 | raise RuntimeError('RAdam does not support sparse gradients') 131 | 132 | p_data_fp32 = p.data.float() 133 | 134 | state = self.state[p] 135 | 136 | if len(state) == 0: 137 | state['step'] = 0 138 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 139 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 140 | else: 141 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 142 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 143 | 144 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 145 | beta1, beta2 = group['betas'] 146 | 147 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 148 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 149 | 150 | state['step'] += 1 151 | beta2_t = beta2 ** state['step'] 152 | N_sma_max = 2 / (1 - beta2) - 1 153 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 154 | 155 | 156 | # more conservative since it's an approximated value 157 | if N_sma >= 5: 158 | if group['weight_decay'] != 0: 159 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 160 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 161 | denom = exp_avg_sq.sqrt().add_(group['eps']) 162 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 163 | p.data.copy_(p_data_fp32) 164 | elif self.degenerated_to_sgd: 165 | if group['weight_decay'] != 0: 166 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 167 | step_size = group['lr'] / (1 - beta1 ** state['step']) 168 | p_data_fp32.add_(-step_size, exp_avg) 169 | p.data.copy_(p_data_fp32) 170 | 171 | return loss 172 | 173 | 174 | class AdamW(Optimizer): 175 | 176 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 177 | if not 0.0 <= lr: 178 | raise ValueError("Invalid learning rate: {}".format(lr)) 179 | if not 0.0 <= eps: 180 | raise ValueError("Invalid epsilon value: {}".format(eps)) 181 | if not 0.0 <= betas[0] < 1.0: 182 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 183 | if not 0.0 <= betas[1] < 1.0: 184 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 185 | 186 | defaults = dict(lr=lr, betas=betas, eps=eps, 187 | weight_decay=weight_decay, warmup = warmup) 188 | super(AdamW, self).__init__(params, defaults) 189 | 190 | def __setstate__(self, state): 191 | super(AdamW, self).__setstate__(state) 192 | 193 | def step(self, closure=None): 194 | loss = None 195 | if closure is not None: 196 | loss = closure() 197 | 198 | for group in self.param_groups: 199 | 200 | for p in group['params']: 201 | if p.grad is None: 202 | continue 203 | grad = 
p.grad.data.float() 204 | if grad.is_sparse: 205 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 206 | 207 | p_data_fp32 = p.data.float() 208 | 209 | state = self.state[p] 210 | 211 | if len(state) == 0: 212 | state['step'] = 0 213 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 214 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 215 | else: 216 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 217 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 218 | 219 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 220 | beta1, beta2 = group['betas'] 221 | 222 | state['step'] += 1 223 | 224 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 225 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 226 | 227 | denom = exp_avg_sq.sqrt().add_(group['eps']) 228 | bias_correction1 = 1 - beta1 ** state['step'] 229 | bias_correction2 = 1 - beta2 ** state['step'] 230 | 231 | if group['warmup'] > state['step']: 232 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 233 | else: 234 | scheduled_lr = group['lr'] 235 | 236 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 237 | 238 | if group['weight_decay'] != 0: 239 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 240 | 241 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 242 | 243 | p.data.copy_(p_data_fp32) 244 | 245 | return loss 246 | 247 | -------------------------------------------------------------------------------- /lib/utils/ply_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | from PIL import Image 4 | from plyfile import PlyData, PlyElement 5 | 6 | def read_pfm(filename): 7 | file = open(filename, 'rb') 8 | color = None 9 | width = None 10 | height = None 11 | scale = None 12 | endian = None 13 | 14 | header = file.readline().decode('utf-8').rstrip() 15 | if header == 'PF': 16 | color = True 17 | elif header == 'Pf': 18 | color = False 19 | else: 20 | raise Exception('Not a PFM file.') 21 | 22 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 23 | if dim_match: 24 | width, height = map(int, dim_match.groups()) 25 | else: 26 | raise Exception('Malformed PFM header.') 27 | 28 | scale = float(file.readline().rstrip()) 29 | if scale < 0: # little-endian 30 | endian = '<' 31 | scale = -scale 32 | else: 33 | endian = '>' # big-endian 34 | 35 | data = np.fromfile(file, endian + 'f') 36 | shape = (height, width, 3) if color else (height, width) 37 | 38 | data = np.reshape(data, shape) 39 | data = np.flipud(data) 40 | file.close() 41 | return data, scale 42 | 43 | def read_camera_parameters(filename): 44 | with open(filename) as f: 45 | lines = f.readlines() 46 | lines = [line.rstrip() for line in lines] 47 | # extrinsics: line [1,5), 4x4 matrix 48 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 49 | # intrinsics: line [7-10), 3x3 matrix 50 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 51 | return intrinsics, extrinsics 52 | 53 | def read_img(filename): 54 | img = Image.open(filename) 55 | # scale 0~255 to 0~1 56 | np_img = np.array(img, dtype=np.float32) / 255. 
57 | return np_img 58 | 59 | 60 | def storePly(path, xyz, rgb): 61 | # Define the dtype for the structured array 62 | dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), 63 | ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), 64 | ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] 65 | 66 | normals = np.zeros_like(xyz) 67 | 68 | elements = np.empty(xyz.shape[0], dtype=dtype) 69 | attributes = np.concatenate((xyz, normals, rgb), axis=1) 70 | elements[:] = list(map(tuple, attributes)) 71 | 72 | # Create the PlyData object and write to file 73 | vertex_element = PlyElement.describe(elements, 'vertex') 74 | ply_data = PlyData([vertex_element]) 75 | ply_data.write(path) -------------------------------------------------------------------------------- /lib/utils/rend_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def normalize(x): 5 | return x / np.linalg.norm(x) 6 | 7 | def ptstocam(pts, c2w): 8 | tt = np.matmul(c2w[:3, :3].T, (pts-c2w[:3, 3])[..., np.newaxis])[..., 0] 9 | return tt 10 | 11 | def viewmatrix(z, up, pos): 12 | vec2 = normalize(z) 13 | vec0_avg = up 14 | vec1 = normalize(np.cross(vec2, vec0_avg)) 15 | vec0 = normalize(np.cross(vec1, vec2)) 16 | m = np.stack([vec0, vec1, vec2, pos], 1) 17 | return m 18 | 19 | def gen_path(RT, center=None, num_views=100): 20 | lower_row = np.array([[0., 0., 0., 1.]]) 21 | 22 | # transfer RT to camera_to_world matrix 23 | RT = np.array(RT) 24 | RT[:] = np.linalg.inv(RT[:]) 25 | 26 | RT = np.concatenate([RT[:, :, 1:2], RT[:, :, 0:1], 27 | -RT[:, :, 2:3], RT[:, :, 3:4]], 2) 28 | 29 | up = normalize(RT[:, :3, 0].sum(0)) # average up vector 30 | z = normalize(RT[0, :3, 2]) 31 | vec1 = normalize(np.cross(z, up)) 32 | vec2 = normalize(np.cross(up, vec1)) 33 | z_off = 0 34 | 35 | if center is None: 36 | center = RT[:, :3, 3].mean(0) 37 | z_off = 1.3 38 | 39 | c2w = np.stack([up, vec1, vec2, center], 1) 40 | 41 | # get radii for spiral path 42 | tt = ptstocam(RT[:, :3, 3], c2w).T 43 | rads = np.percentile(np.abs(tt), 80, -1) 44 | rads = rads * 1.3 45 | rads = np.array(list(rads) + [1.]) 46 | 47 | render_w2c = [] 48 | for theta in np.linspace(0., 2 * np.pi, num_views + 1)[:-1]: 49 | # camera position 50 | cam_pos = np.array([0, np.sin(theta), np.cos(theta), 1] * rads) 51 | cam_pos_world = np.dot(c2w[:3, :4], cam_pos) 52 | # z axis 53 | z = normalize(cam_pos_world - 54 | np.dot(c2w[:3, :4], np.array([z_off, 0, 0, 1.]))) 55 | # vector -> 3x4 matrix (camera_to_world) 56 | mat = viewmatrix(z, up, cam_pos_world) 57 | 58 | mat = np.concatenate([mat[:, 1:2], mat[:, 0:1], 59 | -mat[:, 2:3], mat[:, 3:4]], 1) 60 | mat = np.concatenate([mat, lower_row], 0) 61 | mat = np.linalg.inv(mat) 62 | render_w2c.append(mat) 63 | 64 | return render_w2c 65 | 66 | def create_center_radius(center, radius=5., up='y', ranges=[0, 360, 36], angle_x=0, **kwargs): 67 | center = np.array(center).reshape(1, 3) 68 | thetas = np.deg2rad(np.linspace(*ranges)) 69 | st = np.sin(thetas) 70 | ct = np.cos(thetas) 71 | zero = np.zeros_like(st) 72 | Rotx = cv2.Rodrigues(np.deg2rad(angle_x) * np.array([1., 0., 0.]))[0] 73 | if up == 'z': 74 | center = np.stack([radius*ct, radius*st, zero], axis=1) + center 75 | R = np.stack([-st, ct, zero, zero, zero, zero-1, -ct, -st, zero], axis=-1) 76 | elif up == 'y': 77 | center = np.stack([radius*ct, zero, radius*st, ], axis=1) + center 78 | R = np.stack([ 79 | +st, zero, -ct, 80 | zero, zero-1, zero, 81 | -ct, zero, -st], axis=-1) 82 | R = R.reshape(-1, 3, 3) 83 | R = 
np.einsum('ab,fbc->fac', Rotx, R) 84 | center = center.reshape(-1, 3, 1) 85 | T = - R @ center 86 | RT = np.dstack([R, T]) 87 | return RT 88 | 89 | 90 | def gen_path(RT, center=None, radius=3.2, up='z', ranges=[0, 360, 90], angle_x=27, **kwargs): 91 | ranges = [0, 360, kwargs['num_views']] 92 | c2ws = np.linalg.inv(RT[:]) 93 | if center is None: 94 | center = c2ws[:, :3, 3].mean(0) 95 | center[0], center[1] = 0., 0. 96 | 97 | RTs = [] 98 | center = np.array(center).reshape(1, 3) 99 | thetas = np.deg2rad(np.linspace(*ranges)) 100 | st = np.sin(thetas) 101 | ct = np.cos(thetas) 102 | zero = np.zeros_like(st) 103 | Rotx = cv2.Rodrigues(np.deg2rad(angle_x) * np.array([1., 0., 0.]))[0] 104 | if up == 'z': 105 | center = np.stack([radius*ct, radius*st, zero], axis=1) + center 106 | R = np.stack([-st, ct, zero, zero, zero, zero-1, -ct, -st, zero], axis=-1) 107 | elif up == 'y': 108 | center = np.stack([radius*ct, zero, radius*st, ], axis=1) + center 109 | R = np.stack([ 110 | +st, zero, -ct, 111 | zero, zero-1, zero, 112 | -ct, zero, -st], axis=-1) 113 | R = R.reshape(-1, 3, 3) 114 | R = np.einsum('ab,fbc->fac', Rotx, R) 115 | center = center.reshape(-1, 3, 1) 116 | T = - R @ center 117 | RT = np.dstack([R, T]) 118 | RT_bottom = np.zeros_like(RT[:, :1]) 119 | RT_bottom[:, :, 3] = 1 120 | # __import__('ipdb').set_trace() 121 | ext = np.concatenate([RT, RT_bottom], axis=1) 122 | c2w = np.linalg.inv(ext) 123 | # __import__('ipdb').set_trace() 124 | # import matplotlib.pyplot as plt 125 | # plt.plot(c2ws[:, 0, 3], c2ws[:, 1, 3], '.') 126 | # plt.plot(c2w[:, 0, 3], c2w[:, 1, 3], '.') 127 | # plt.show() 128 | return ext 129 | 130 | def gen_nerf_path(c2ws, depth_ranges, rads_scale=.5, N_views=60): 131 | c2w = poses_avg(c2ws) 132 | up = normalize(c2ws[:, :3, 1].sum(0)) 133 | 134 | close_depth, inf_depth = depth_ranges 135 | dt = .75 136 | mean_dz = 1./(( (1.-dt)/close_depth + dt/inf_depth )) 137 | focal = mean_dz 138 | 139 | shrink_factor = .8 140 | zdelta = close_depth * .2 141 | tt = c2ws[:, :3, 3] = c2w[:3, 3][None] 142 | rads = np.percentile(np.abs(tt), 70, 0)*rads_scale 143 | 144 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views) 145 | return render_poses 146 | 147 | def poses_avg(poses): 148 | center = poses[:, :3, 3].mean(0) 149 | vec2 = normalize(poses[:, :3, 2].sum(0)) 150 | up = poses[:, :3, 1].sum(0) 151 | c2w = viewmatrix(vec2, up, center) 152 | return c2w 153 | 154 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120): 155 | render_poses = [] 156 | rads = np.array(list(rads) + [1.]) 157 | 158 | for theta in np.linspace(0., 2. 
* np.pi * N_rots, N+1)[:-1]: 159 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 160 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 161 | render_poses.append(viewmatrix(z, up, c)) 162 | return render_poses 163 | -------------------------------------------------------------------------------- /lib/utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def get_bound_corners(bounds): 5 | min_x, min_y, min_z = bounds[0] 6 | max_x, max_y, max_z = bounds[1] 7 | corners_3d = np.array([ 8 | [min_x, min_y, min_z], 9 | [min_x, min_y, max_z], 10 | [min_x, max_y, min_z], 11 | [min_x, max_y, max_z], 12 | [max_x, min_y, min_z], 13 | [max_x, min_y, max_z], 14 | [max_x, max_y, min_z], 15 | [max_x, max_y, max_z], 16 | ]) 17 | return corners_3d 18 | 19 | def project(xyz, K, RT): 20 | """ 21 | xyz: [N, 3] 22 | K: [3, 3] 23 | RT: [3, 4] 24 | """ 25 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T 26 | 27 | xyz = np.dot(xyz, K.T) 28 | xy = xyz[:, :2] / xyz[:, 2:] 29 | return xy 30 | 31 | row_col_ = { 32 | 2: (2, 1), 33 | 7: (2, 4), 34 | 8: (2, 4), 35 | 9: (3, 3), 36 | 26: (4, 7) 37 | } 38 | 39 | row_col_square = { 40 | 2: (2, 1), 41 | 7: (3, 3), 42 | 8: (3, 3), 43 | 9: (3, 3), 44 | 26: (5, 5) 45 | } 46 | 47 | def get_row_col(l, square): 48 | if square and l in row_col_square.keys(): 49 | return row_col_square[l] 50 | if l in row_col_.keys(): 51 | return row_col_[l] 52 | else: 53 | from math import sqrt 54 | row = int(sqrt(l) + 0.5) 55 | col = int(l/ row + 0.5) 56 | if row*col < l: 57 | col = col + 1 58 | if row > col: 59 | row, col = col, row 60 | return row, col 61 | 62 | def merge(images, row=-1, col=-1, resize=False, ret_range=False, square=False, **kwargs): 63 | if row == -1 and col == -1: 64 | row, col = get_row_col(len(images), square) 65 | height = images[0].shape[0] 66 | width = images[0].shape[1] 67 | # special case 68 | if height > width: 69 | if len(images) == 3: 70 | row, col = 1, 3 71 | if len(images[0].shape) > 2: 72 | ret_img = np.zeros((height * row, width * col, images[0].shape[2]), dtype=np.uint8) + 255 73 | else: 74 | ret_img = np.zeros((height * row, width * col), dtype=np.uint8) + 255 75 | ranges = [] 76 | for i in range(row): 77 | for j in range(col): 78 | if i*col + j >= len(images): 79 | break 80 | img = images[i * col + j] 81 | # resize the image size 82 | img = cv2.resize(img, (width, height)) 83 | ret_img[height * i: height * (i+1), width * j: width * (j+1)] = img 84 | ranges.append((width*j, height*i, width*(j+1), height*(i+1))) 85 | if resize: 86 | min_height = 1000 87 | if ret_img.shape[0] > min_height: 88 | scale = min_height/ret_img.shape[0] 89 | ret_img = cv2.resize(ret_img, None, fx=scale, fy=scale) 90 | if ret_range: 91 | return ret_img, ranges 92 | return ret_img 93 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml 2 | opencv-python 3 | imgaug 4 | plyfile 5 | tqdm 6 | kornia 7 | ipdb 8 | lpips 9 | tensorboardX 10 | glfw 11 | pyglm 12 | pyopengl 13 | imgui 14 | termcolor 15 | trimesh 16 | scikit-image==0.19.0 17 | imageio==2.27.0 18 | imageio[ffmpeg] 19 | imageio[pyav] -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg, args 2 | from lib.utils.ply_utils import *
3 | import numpy as np 4 | import os 5 | import glob 6 | 7 | def run_dataset(): 8 | from lib.datasets import make_data_loader 9 | import tqdm 10 | 11 | cfg.train.num_workers = 0 12 | data_loader = make_data_loader(cfg, is_train=False) 13 | for batch in tqdm.tqdm(data_loader): 14 | pass 15 | 16 | def run_network(): 17 | from lib.networks import make_network 18 | from lib.datasets import make_data_loader 19 | from lib.utils.net_utils import load_network 20 | from lib.utils.data_utils import to_cuda 21 | import tqdm 22 | import torch 23 | import time 24 | 25 | network = make_network(cfg).cuda() 26 | load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch) 27 | network.eval() 28 | 29 | data_loader = make_data_loader(cfg, is_train=False) 30 | total_time = 0 31 | for batch in tqdm.tqdm(data_loader): 32 | batch = to_cuda(batch) 33 | with torch.no_grad(): 34 | torch.cuda.synchronize() 35 | start = time.time() 36 | network(batch) 37 | torch.cuda.synchronize() 38 | total_time += time.time() - start 39 | print(total_time / len(data_loader)) 40 | 41 | def run_evaluate(): 42 | from lib.datasets import make_data_loader 43 | from lib.evaluators import make_evaluator 44 | import tqdm 45 | import torch 46 | from lib.networks import make_network 47 | from lib.utils import net_utils 48 | import time 49 | 50 | network = make_network(cfg).cuda() 51 | net_utils.load_network(network, 52 | cfg.trained_model_dir, 53 | resume=cfg.resume, 54 | epoch=cfg.test.epoch) 55 | network.eval() 56 | 57 | data_loader = make_data_loader(cfg, is_train=False) 58 | evaluator = make_evaluator(cfg) 59 | net_time = [] 60 | for batch in tqdm.tqdm(data_loader): 61 | for k in batch: 62 | if k != 'meta': 63 | if k == 'rendering_video_meta': 64 | for i in range(len(batch[k])): 65 | for v in batch[k][i]: 66 | batch[k][i][v] = batch[k][i][v].cuda() 67 | else: 68 | batch[k] = batch[k].cuda() 69 | if cfg.save_video: 70 | with torch.no_grad(): 71 | network(batch) 72 | else: 73 | with torch.no_grad(): 74 | torch.cuda.synchronize() 75 | start_time = time.time() 76 | output = network(batch) 77 | torch.cuda.synchronize() 78 | end_time = time.time() 79 | net_time.append(end_time - start_time) 80 | evaluator.evaluate(output, batch) 81 | 82 | if not cfg.save_video: 83 | evaluator.summarize() 84 | if len(net_time) > 1: 85 | # print('net_time: ', np.mean(net_time[1:])) 86 | print('FPS: ', 1./np.mean(net_time[1:])) 87 | else: 88 | # print('net_time: ', np.mean(net_time)) 89 | print('FPS: ', 1./np.mean(net_time)) 90 | 91 | 92 | if cfg.save_ply: 93 | dataset_name = cfg.train_dataset_module.split('.')[-2] 94 | ply_dir = os.path.join(cfg.result_dir, 'pointclouds', dataset_name) 95 | for item in os.listdir(ply_dir): 96 | data_dir = os.path.join(ply_dir, item) 97 | img_dir = os.path.join(data_dir, 'images') 98 | depth_dir = os.path.join(data_dir, 'depth') 99 | cam_dir = os.path.join(data_dir, 'cam') 100 | img_ls = glob.glob(os.path.join(img_dir, '*.png')) 101 | img_name = [os.path.basename(im).split('.')[0] for im in img_ls] 102 | 103 | # for the final point cloud 104 | vertexs = [] 105 | vertex_colors = [] 106 | 107 | for name in img_name: 108 | ref_name = name 109 | 110 | ref_intrinsics, ref_extrinsics = read_camera_parameters(os.path.join(cam_dir, ref_name+'.txt')) 111 | ref_img = read_img(os.path.join(img_dir, ref_name+'.png')) 112 | ref_depth_est = read_pfm(os.path.join(depth_dir, ref_name+'.pfm'))[0] 113 | 114 | height, width = ref_depth_est.shape[:2] 115 | x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) 116 | x, y = 
x.reshape(-1), y.reshape(-1) 117 | depth = ref_depth_est.reshape(-1) 118 | color = ref_img.reshape(-1,3) 119 | xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), 120 | np.vstack((x, y, np.ones_like(x))) * depth) 121 | xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), 122 | np.vstack((xyz_ref, np.ones_like(x))))[:3] 123 | vertexs.append(xyz_world.transpose((1, 0))) 124 | vertex_colors.append((color * 255).astype(np.uint8)) 125 | vertexs = np.concatenate(vertexs, axis=0) 126 | vertex_colors = np.concatenate(vertex_colors, axis=0) 127 | scene = os.path.basename(data_dir) 128 | ply_path = os.path.join(data_dir, f'{scene}.ply') 129 | print(f'saving {ply_path}') 130 | storePly(ply_path, vertexs, vertex_colors) 131 | 132 | ## point cloud --> mesh 133 | import open3d as o3d 134 | import trimesh 135 | pcd = o3d.io.read_point_cloud(ply_path) 136 | pcd.estimate_normals() 137 | 138 | # estimate radius for rolling ball 139 | distances = pcd.compute_nearest_neighbor_distance() 140 | avg_dist = np.mean(distances) 141 | radius = 1.5 * avg_dist 142 | 143 | mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting( 144 | pcd, 145 | o3d.utility.DoubleVector([radius, radius * 2])) 146 | 147 | # create the triangular mesh with the vertices and faces from open3d 148 | tri_mesh = trimesh.Trimesh(np.asarray(mesh.vertices), np.asarray(mesh.triangles), 149 | vertex_normals=np.asarray(mesh.vertex_normals)) 150 | 151 | trimesh.convex.is_convex(tri_mesh) 152 | ply_path = os.path.join(data_dir, f'{scene}_mesh.ply') 153 | trimesh.exchange.export.export_mesh(tri_mesh, ply_path) 154 | 155 | 156 | 157 | if __name__ == '__main__': 158 | globals()['run_' + args.type]() 159 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg, args 2 | from lib.networks import make_network 3 | from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler 4 | from lib.datasets import make_data_loader 5 | from lib.utils.net_utils import load_model, save_model, load_network, save_trained_config, load_pretrain 6 | from lib.evaluators import make_evaluator 7 | import torch.multiprocessing 8 | import torch 9 | import torch.distributed as dist 10 | import os 11 | # torch.autograd.set_detect_anomaly(True) 12 | 13 | if cfg.fix_random: 14 | torch.manual_seed(0) 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | 18 | 19 | def train(cfg, network): 20 | train_loader = make_data_loader(cfg, 21 | is_train=True, 22 | is_distributed=cfg.distributed, 23 | max_iter=cfg.ep_iter) 24 | if cfg.skip_eval: 25 | val_loader = None 26 | else: 27 | val_loader = make_data_loader(cfg, is_train=False) 28 | trainer = make_trainer(cfg, network, train_loader) 29 | optimizer = make_optimizer(cfg, network) 30 | scheduler = make_lr_scheduler(cfg, optimizer) 31 | recorder = make_recorder(cfg) 32 | evaluator = make_evaluator(cfg) 33 | 34 | begin_epoch = load_model(network, 35 | optimizer, 36 | scheduler, 37 | recorder, 38 | cfg.trained_model_dir, 39 | resume=cfg.resume) 40 | if begin_epoch == 0 and cfg.pretrain != '': 41 | load_pretrain(network, cfg.pretrain) 42 | 43 | 44 | set_lr_scheduler(cfg, scheduler) 45 | psnr_best, ssim_best, lpips_best = 0, 0, 10 46 | for epoch in range(begin_epoch, cfg.train.epoch): 47 | recorder.epoch = epoch 48 | if cfg.distributed: 49 | train_loader.batch_sampler.sampler.set_epoch(epoch) 50 | 51 | 
train_loader.dataset.epoch = epoch 52 | 53 | trainer.train(epoch, train_loader, optimizer, recorder) 54 | scheduler.step() 55 | 56 | if (epoch + 1) % cfg.save_ep == 0 and cfg.local_rank == 0: 57 | save_model(network, optimizer, scheduler, recorder, 58 | cfg.trained_model_dir, epoch) 59 | 60 | if (epoch + 1) % cfg.save_latest_ep == 0 and cfg.local_rank == 0: 61 | save_model(network, 62 | optimizer, 63 | scheduler, 64 | recorder, 65 | cfg.trained_model_dir, 66 | epoch, 67 | last=True) 68 | 69 | if not cfg.skip_eval and (epoch + 1) % cfg.eval_ep == 0 and cfg.local_rank == 0: 70 | result = trainer.val(epoch, val_loader, evaluator, recorder) 71 | psnr = result['psnr'] 72 | ssim = result['ssim'] 73 | lpips = result['lpips'] 74 | if psnr > psnr_best: 75 | psnr_best = psnr 76 | save_model(network, 77 | optimizer, 78 | scheduler, 79 | recorder, 80 | cfg.trained_model_dir, 81 | epoch, 82 | custom='psnr_best') 83 | if ssim > ssim_best: 84 | ssim_best = ssim 85 | save_model(network, 86 | optimizer, 87 | scheduler, 88 | recorder, 89 | cfg.trained_model_dir, 90 | epoch, 91 | custom='ssim_best') 92 | if lpips < lpips_best: 93 | lpips_best = lpips 94 | save_model(network, 95 | optimizer, 96 | scheduler, 97 | recorder, 98 | cfg.trained_model_dir, 99 | epoch, 100 | custom='lpips_best') 101 | print(f'psnr_best: {psnr_best:.2f}, ssim_best: {ssim_best:.3f}, lpips_best: {lpips_best:.3f}') 102 | 103 | 104 | 105 | return network 106 | 107 | 108 | def test(cfg, network): 109 | trainer = make_trainer(cfg, network) 110 | val_loader = make_data_loader(cfg, is_train=False) 111 | evaluator = make_evaluator(cfg) 112 | epoch = load_network(network, 113 | cfg.trained_model_dir, 114 | resume=cfg.resume, 115 | epoch=cfg.test.epoch) 116 | trainer.val(epoch, val_loader, evaluator) 117 | 118 | def synchronize(): 119 | """ 120 | Helper function to synchronize (barrier) among all processes when 121 | using distributed training 122 | """ 123 | if not dist.is_available(): 124 | return 125 | if not dist.is_initialized(): 126 | return 127 | world_size = dist.get_world_size() 128 | if world_size == 1: 129 | return 130 | dist.barrier() 131 | 132 | def main(): 133 | if cfg.distributed: 134 | cfg.local_rank = int(os.environ['RANK']) % torch.cuda.device_count() 135 | torch.cuda.set_device(cfg.local_rank) 136 | torch.distributed.init_process_group(backend="nccl", 137 | init_method="env://") 138 | synchronize() 139 | 140 | network = make_network(cfg) 141 | if args.test: 142 | test(cfg, network) 143 | else: 144 | train(cfg, network) 145 | if cfg.local_rank == 0: 146 | print('Success!') 147 | print('='*80) 148 | os.system('kill -9 {}'.format(os.getpid())) 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | --------------------------------------------------------------------------------
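A few hedged usage sketches for the utility modules listed above follow; every name marked as a placeholder is illustrative and not part of the repository. `save_pretrain` in `lib/utils/net_utils.py` writes a weights-only checkpoint to `data/trained_model/<task>/<model_dir>/latest.pth`; its counterpart `load_pretrain` (called in `train_net.py` as `load_pretrain(network, cfg.pretrain)` when training starts from epoch 0 with a pretrain path set) restores only the `'net'` state dict, without optimizer or scheduler state. A minimal sketch with a placeholder module and made-up directory names:

```python
import torch
from lib.utils.net_utils import save_pretrain

net = torch.nn.Linear(8, 8)               # stand-in for the GeFu network

# Writes data/trained_model/gefu/demo_run/latest.pth (the directory is created).
save_pretrain(net, task='gefu', model_dir='demo_run')
```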
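The schedulers in `lib/utils/optimizer/lr_scheduler.py` follow a standard warmup-then-multistep recipe: `WarmupMultiStepLR` ramps the learning rate linearly from `warmup_factor * base_lr` to `base_lr` over the first `warmup_iters` steps and then multiplies by `gamma` at every milestone, while `ExponentialLR` applies `base_lr * gamma ** (epoch / decay_epochs)`. A minimal usage sketch, assuming a throwaway SGD optimizer; the base LR and milestones are illustrative and do not come from the GeFu configs:

```python
import torch
from lib.utils.optimizer.lr_scheduler import WarmupMultiStepLR

# Placeholder parameter/optimizer; all values here are illustrative only.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)

# Linear warmup over 5 epochs, then decay by 0.1 at epochs 50 and 80.
scheduler = WarmupMultiStepLR(optimizer, milestones=[50, 80], gamma=0.1,
                              warmup_factor=1.0 / 3, warmup_iters=5)

for epoch in range(100):
    # ... one epoch of training would go here ...
    scheduler.step()
    if epoch in (0, 4, 49, 50, 79, 80):
        print(epoch, scheduler.get_last_lr())
```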
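`RAdam` in `lib/utils/optimizer/radam.py` is the rectified-Adam variant: it tracks the length `N_sma` of the approximated simple moving average and only enables the variance-normalized update once `N_sma >= 5`, otherwise falling back to an SGD-style step when `degenerated_to_sgd` is true. It is used like any other PyTorch optimizer; a hedged sketch with a placeholder model follows (note that the `add_`/`addcmul_` calls in this file use an older positional overload that recent PyTorch releases deprecate, so a warning, or on some versions an error, is possible):

```python
import torch
from lib.utils.optimizer.radam import RAdam

model = torch.nn.Linear(16, 4)                      # placeholder network
optimizer = RAdam(model.parameters(), lr=5e-4, weight_decay=0.0)

x, target = torch.randn(8, 16), torch.randn(8, 4)   # dummy batch
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()
```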
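`read_camera_parameters` in `lib/utils/ply_utils.py` expects an MVSNet-style camera text layout, the same format consumed by the point-cloud export in `run.py`: an `extrinsic` header followed by a 4x4 world-to-camera matrix on lines 1-4, a blank line, then an `intrinsic` header with the 3x3 K matrix on lines 7-9. A toy file matching that layout (all values and the file name are placeholders):

```python
import numpy as np
from lib.utils.ply_utils import read_camera_parameters

lines = [
    "extrinsic",
    "1 0 0 0", "0 1 0 0", "0 0 1 4", "0 0 0 1",
    "",
    "intrinsic",
    "800 0 320", "0 800 240", "0 0 1",
]
with open("toy_cam.txt", "w") as f:           # hypothetical file name
    f.write("\n".join(lines) + "\n")

K, w2c = read_camera_parameters("toy_cam.txt")
print(K.shape, w2c.shape)                     # (3, 3) (4, 4)
```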
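The `save_ply` branch of `run.py` converts each rendered depth map into a world-space point cloud before writing it with `storePly` from `lib/utils/ply_utils.py`: a pixel `(u, v)` with depth `d` is unprojected as `X_ref = d * K^{-1} [u, v, 1]^T` and then lifted to world coordinates with the inverse extrinsic matrix. The same computation on synthetic inputs, with camera values and the output file name made up for illustration:

```python
import numpy as np
from lib.utils.ply_utils import storePly

# Synthetic 4x4 depth map and colors with a toy pinhole camera.
H, W = 4, 4
depth = np.full((H, W), 2.0, dtype=np.float32)
rgb = np.random.rand(H, W, 3).astype(np.float32)
K = np.array([[2., 0., W / 2.], [0., 2., H / 2.], [0., 0., 1.]], dtype=np.float32)
E = np.eye(4, dtype=np.float32)                      # world-to-camera extrinsics

u, v = np.meshgrid(np.arange(W), np.arange(H))
u, v, d = u.reshape(-1), v.reshape(-1), depth.reshape(-1)

# X_ref = d * K^-1 [u, v, 1]^T, then X_world = E^-1 [X_ref, 1]^T.
xyz_ref = np.linalg.inv(K) @ (np.vstack((u, v, np.ones_like(u))) * d)
xyz_world = (np.linalg.inv(E) @ np.vstack((xyz_ref, np.ones_like(u))))[:3]

storePly('toy_scene.ply', xyz_world.T, (rgb.reshape(-1, 3) * 255).astype(np.uint8))
```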
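`lib/utils/rend_utils.py` defines `gen_path` twice; the second definition shadows the first, so the version in effect sweeps the camera on a circle of the given `radius` around the scene center, tilted by `angle_x` degrees, and returns one world-to-camera matrix per view. A hedged sketch on dummy poses, where the pose values and view count are placeholders:

```python
import numpy as np
from lib.utils.rend_utils import gen_path

# Three dummy world-to-camera matrices, all 4 units in front of the origin.
RT = np.tile(np.eye(4, dtype=np.float32), (3, 1, 1))
RT[:, 2, 3] = 4.0

# 60 poses on a circle of radius 3.2 tilted by ~27 degrees (num_views is required).
render_w2c = gen_path(RT, radius=3.2, up='z', angle_x=27, num_views=60)
print(render_w2c.shape)                              # (60, 4, 4)
```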
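`merge` in `lib/utils/vis_utils.py` tiles a list of equally sized images into one grid, choosing the layout with `get_row_col` unless `row`/`col` are given explicitly. A small sketch with dummy images (the sizes are arbitrary):

```python
import numpy as np
from lib.utils.vis_utils import merge

# Five dummy 120x160 RGB images with different gray levels.
images = [np.full((120, 160, 3), 40 * i, dtype=np.uint8) for i in range(5)]

grid = merge(images)                      # layout picked by get_row_col -> 2x3
print(grid.shape)                         # (240, 480, 3); the unused cell stays white

row_strip = merge(images, row=1, col=5)   # force a single row
print(row_strip.shape)                    # (120, 800, 3)
```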