├── .gitignore ├── LICENSE ├── README.md ├── configs ├── default.yml ├── inb │ ├── inb_377.yaml │ ├── inb_386.yaml │ ├── inb_387.yaml │ ├── inb_390.yaml │ ├── inb_392.yaml │ ├── inb_393.yaml │ ├── inb_394.yaml │ ├── inb_lan.yaml │ ├── inb_marc.yaml │ ├── inb_olek.yaml │ └── inb_vlad.yaml └── monocular.yml ├── docs ├── install.md ├── media │ └── inb.gif ├── preprocess-CN.md └── preprocess.md ├── environment.yml ├── lib ├── __init__.py ├── config │ ├── __init__.py │ ├── config.py │ └── yacs.py ├── csrc │ ├── pointnet2 │ │ ├── pointnet2_modules.py │ │ ├── pointnet2_utils.py │ │ ├── pytorch_utils.py │ │ ├── setup.py │ │ └── src │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── ball_query_gpu.h │ │ │ ├── cuda_utils.h │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── group_points_gpu.h │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ ├── interpolate_gpu.h │ │ │ ├── pointnet2_api.cpp │ │ │ ├── sampling.cpp │ │ │ ├── sampling_gpu.cu │ │ │ └── sampling_gpu.h │ └── torchsearchsorted │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── examples │ │ ├── benchmark.py │ │ └── test.py │ │ ├── setup.py │ │ ├── src │ │ ├── cpu │ │ │ ├── searchsorted_cpu_wrapper.cpp │ │ │ └── searchsorted_cpu_wrapper.h │ │ ├── cuda │ │ │ ├── searchsorted_cuda_kernel.cu │ │ │ ├── searchsorted_cuda_kernel.h │ │ │ ├── searchsorted_cuda_wrapper.cpp │ │ │ └── searchsorted_cuda_wrapper.h │ │ └── torchsearchsorted │ │ │ ├── __init__.py │ │ │ ├── searchsorted.py │ │ │ └── utils.py │ │ └── test │ │ ├── conftest.py │ │ └── test_searchsorted.py ├── datasets │ ├── __init__.py │ ├── collate_batch.py │ ├── h36m │ │ ├── tpose_dataset.py │ │ ├── tpose_dataset_mesh.py │ │ └── tpose_novel_view_dataset.py │ ├── make_dataset.py │ ├── samplers.py │ └── transforms.py ├── evaluators │ ├── __init__.py │ ├── if_nerf.py │ └── make_evaluator.py ├── networks │ ├── __init__.py │ ├── bw_deform │ │ ├── inb_part_network_multiassign.py │ │ └── part_base_network.py │ ├── deformers │ │ └── uv_deformer.py │ ├── embedder.py │ ├── embedders │ │ ├── freq_embedder.py │ │ └── part_base_embedder.py │ ├── make_network.py │ └── renderer │ │ ├── __init__.py │ │ ├── inb_renderer.py │ │ ├── make_renderer.py │ │ ├── nerf_net_utils.py │ │ └── pose_mesh_renderer.py ├── train │ ├── __init__.py │ ├── optimizer.py │ ├── recorder.py │ ├── scheduler.py │ └── trainers │ │ ├── __init__.py │ │ ├── crit.py │ │ ├── inb_trainer.py │ │ ├── loss │ │ ├── __init__.py │ │ ├── fourier_loss.py │ │ ├── perceptual_loss.py │ │ └── tv_image_loss.py │ │ ├── make_trainer.py │ │ └── trainer.py ├── utils │ ├── base_utils.py │ ├── blend_utils.py │ ├── debug_utils.py │ ├── if_nerf │ │ ├── if_nerf_data_utils.py │ │ └── if_nerf_net_utils.py │ ├── img_utils.py │ ├── loss_utils.py │ ├── net_utils.py │ ├── optimizer │ │ ├── lr_scheduler.py │ │ └── radam.py │ └── render_utils.py └── visualizers │ ├── __init__.py │ ├── if_nerf.py │ ├── if_nerf_demo.py │ └── make_visualizer.py ├── requirements.txt ├── run.py ├── scripts ├── eval_monocap.sh └── eval_zjumocap.sh ├── tools ├── cropschp.py ├── easymocap2instant-nvr.py ├── monocular.py └── prepare_zjumocap.py └── train_net.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | .ipynb_checkpoints/ 4 | *.py[cod] 5 | *.so 6 | *.orig 7 | *.o 8 | *.json 9 | *.pth 10 | *.npy 11 | *.ipynb 12 | /data/ 13 | exps/ 14 | debug/ 15 | *.png 16 | exps 17 | *.ply 18 | *.obj 19 | ._.DS_Store 20 | .DS_Store 21 | exps* 22 | data/ 23 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | //////////////////////////////////////////////////////////////////////////// 2 | // Copyright 2022-2023 the 3D Vision Group at the State Key Lab of CAD&CG, 3 | // Zhejiang University. All Rights Reserved. 4 | // 5 | // For more information see 6 | // If you use this code, please cite the corresponding publications as 7 | // listed on the above website. 8 | // 9 | // Permission to use, copy, modify and distribute this software and its 10 | // documentation for educational, research and non-profit purposes only. 11 | // Any modification based on this work must be open source and prohibited 12 | // for commercial use. 13 | // You must retain, in the source form of any derivative works that you 14 | // distribute, all copyright, patent, trademark, and attribution notices 15 | // from the source form of this work. 16 | // 17 | // 18 | //////////////////////////////////////////////////////////////////////////// 19 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Neural Volumetric Representations of Dynamic Humans in Minutes 2 | 3 | ### [Project Page](https://zju3dv.github.io/instant_nvr) | [Video](https://zju3dv.github.io/instant_nvr) | [Paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Geng_Learning_Neural_Volumetric_Representations_of_Dynamic_Humans_in_Minutes_CVPR_2023_paper.pdf) | [Data](https://github.com/zju3dv/instant-nvr/blob/master/docs/install.md#set-up-datasets) 4 | 5 | ![inb](docs/media/inb.gif) 6 | 7 | > [Learning Neural Volumetric Representations of Dynamic Humans in Minutes](https://zju3dv.github.io/instant_nvr) 8 | > 9 | > Chen Geng\*, Sida Peng\*, Zhen Xu\*, Hujun Bao, Xiaowei Zhou (* denotes equal contribution) 10 | > 11 | > CVPR 2023 12 | 13 | ## Installation 14 | 15 | See [here](./docs/install.md). 16 | 17 | ## Reproducing results in the paper 18 | 19 | We provide two scripts to help reproduce the results shown in the paper. 20 | 21 | After setting up the environment and the datasets, run the following for evaluation on the ZJU-MoCap dataset: 22 | 23 | ```shell 24 | sh scripts/eval_zjumocap.sh 25 | ``` 26 | 27 | For evaluation on the MonoCap dataset, run: 28 | 29 | ```shell 30 | sh scripts/eval_monocap.sh 31 | ``` 32 | 33 | 34 | ## Evaluation on ZJU-MoCap 35 | 36 | Let's take "377" as an example. 37 | 38 | Training on ZJU-MoCap can be done by running: 39 | 40 | ```shell 41 | export name=377 42 | python train_net.py --cfg_file configs/inb/inb_${name}.yaml exp_name inb_${name} gpus ${GPUS} 43 | ``` 44 | 45 | Evaluation can be done by running: 46 | ```shell 47 | export name=377 48 | python run.py --type evaluate --cfg_file configs/inb/inb_${name}.yaml exp_name inb_${name} gpus ${GPUS} 49 | ``` 50 | 51 | ## Evaluation on MonoCap 52 | 53 | Let's take "lan" as an example. 
54 | 55 | Training on Monocap can be done by running: 56 | 57 | ```shell 58 | export name=lan 59 | python train_net.py --cfg_file configs/inb/inb_${name}.yaml exp_name inb_${name} gpus ${GPUS} 60 | ``` 61 | 62 | Evaluation can be done by running: 63 | ```shell 64 | export name=lan 65 | python run.py --type evaluate --cfg_file configs/inb/inb_${name}.yaml exp_name inb_${name} gpus ${GPUS} 66 | ``` 67 | 68 | ## TODO List 69 | 70 | This repository currently serves as the release of the technical paper's implementation and will undergo future updates (planned below) to enhance user-friendliness. We warmly welcome and appreciate any contributions. 71 | 72 | - [x] Instruction on running on custom datasets (Kudos to [@tian42chen](https://github.com/tian42chen)!!) 73 | - [ ] Add support for further acceleration using CUDA 74 | - [ ] Add a Google Colab notebook demo 75 | 76 | ## Bibtex 77 | 78 | If you find the repo useful for your research, please consider citing our paper: 79 | 80 | ``` 81 | @inproceedings{instant_nvr, 82 | title={Learning Neural Volumetric Representations of Dynamic Humans in Minutes}, 83 | author={Chen Geng and Sida Peng and Zhen Xu and Hujun Bao and Xiaowei Zhou}, 84 | booktitle={CVPR}, 85 | year={2023} 86 | } 87 | ``` 88 | -------------------------------------------------------------------------------- /configs/default.yml: -------------------------------------------------------------------------------- 1 | task: "monocular" 2 | gpus: [0] 3 | 4 | zju_human: "" 5 | 6 | train_dataset_module: "lib.datasets.h36m.tpose_dataset" 7 | val_dataset_module: "lib.datasets.h36m.tpose_dataset" 8 | test_dataset_module: "lib.datasets.h36m.tpose_dataset" 9 | prune_dataset_module: "lib.datasets.h36m.tpose_dataset" 10 | 11 | network_module: "lib.networks.bw_deform.inb_part_network_multiassign" 12 | renderer_module: "lib.networks.renderer.inb_renderer" 13 | 14 | renderer_vis_module: "lib.networks.renderer.inb_renderer" 15 | trainer_module: "lib.train.trainers.inb_trainer" 16 | evaluator_module: "lib.evaluators.if_nerf" 17 | visualizer_module: "lib.visualizers.if_nerf" 18 | 19 | network: 20 | occ: 21 | d_hidden: 64 22 | n_layers: 1 23 | color: 24 | d_hidden: 64 25 | n_layers: 2 26 | 27 | viewdir_embedder: 28 | module: "lib.networks.embedders.freq_embedder" 29 | kwargs: 30 | res: 4 31 | input_dims: 3 32 | 33 | # training options 34 | train_th: 0.1 35 | norm_th: 0.1 36 | 37 | # rendering options 38 | i_embed: 0 39 | xyz_res: 10 40 | view_res: 4 41 | raw_noise_std: 0 42 | 43 | N_samples: 64 44 | N_importance: 128 45 | N_rand: 1024 46 | 47 | perturb: 1 48 | white_bkgd: False 49 | 50 | render_views: 4 51 | 52 | # data options 53 | ratio: 0.5 54 | 55 | # partnet config 56 | tpose_deformer: 57 | module: "lib.networks.deformers.uv_deformer" 58 | embedder: 59 | module: "lib.networks.embedders.part_base_embedder" 60 | kwargs: 61 | n_levels: 8 62 | n_features_per_level: 2 63 | log2_hashmap_size: 14 64 | base_resolution: 4 65 | b: 1.38 66 | sum: False 67 | sum_over_features: True 68 | separate_dense: True 69 | use_batch_bounds: False 70 | include_input: True 71 | 72 | partnet: 73 | body: 74 | module: "lib.networks.bw_deform.part_base_network" 75 | embedder: 76 | module: "lib.networks.embedders.part_base_embedder" 77 | kwargs: 78 | n_levels: 16 79 | n_features_per_level: 16 80 | log2_hashmap_size: 20 81 | base_resolution: 16 82 | b: 1.38 83 | sum: True 84 | sum_over_features: True 85 | separate_dense: True 86 | use_batch_bounds: True 87 | bbox: [[-1, -1.2, -0.34], [0.8, 0.7, 0.5]] 88 | # pretrained: 
'exps/inb/inb_377_body1/trained_model/body.pth' 89 | leg: 90 | module: "lib.networks.bw_deform.part_base_network" 91 | embedder: 92 | module: "lib.networks.embedders.part_base_embedder" 93 | kwargs: 94 | n_levels: 16 95 | n_features_per_level: 16 96 | log2_hashmap_size: 20 97 | base_resolution: 2 98 | b: 1.38 99 | sum: True 100 | sum_over_features: True 101 | separate_dense: True 102 | use_batch_bounds: True 103 | color_network: 104 | module: "lib.networks.bw_deform.part_base_network" 105 | kwargs: 106 | d_hidden: 64 107 | n_layers: 1 108 | bbox: [[-1, -1.2, -0.34], [0.8, -0.3, 0.5]] 109 | head: 110 | module: "lib.networks.bw_deform.part_base_network" 111 | embedder: 112 | module: "lib.networks.embedders.part_base_embedder" 113 | kwargs: 114 | n_levels: 16 115 | n_features_per_level: 16 116 | log2_hashmap_size: 18 117 | base_resolution: 2 118 | b: 1.38 119 | sum: True 120 | sum_over_features: True 121 | separate_dense: True 122 | use_batch_bounds: True 123 | bbox: [[-0.3, 0.3, -0.3], [0.3, 0.7, 0.3]] 124 | larm: 125 | module: "lib.networks.bw_deform.part_base_network" 126 | embedder: 127 | module: "lib.networks.embedders.part_base_embedder" 128 | kwargs: 129 | n_levels: 16 130 | n_features_per_level: 16 131 | log2_hashmap_size: 15 132 | base_resolution: 2 133 | b: 1.38 134 | sum: True 135 | sum_over_features: True 136 | separate_dense: True 137 | use_batch_bounds: True 138 | color_network: 139 | module: "lib.networks.bw_deform.part_base_network" 140 | kwargs: 141 | d_hidden: 64 142 | n_layers: 1 143 | bbox: [[0.2, 0, -0.2], [0.9, 0.35, 0.2]] 144 | rarm: 145 | module: "lib.networks.bw_deform.part_base_network" 146 | embedder: 147 | module: "lib.networks.embedders.part_base_embedder" 148 | kwargs: 149 | n_levels: 16 150 | n_features_per_level: 16 151 | log2_hashmap_size: 15 152 | base_resolution: 2 153 | b: 1.38 154 | sum: True 155 | sum_over_features: True 156 | separate_dense: True 157 | use_batch_bounds: True 158 | color_network: 159 | module: "lib.networks.bw_deform.part_base_network" 160 | kwargs: 161 | d_hidden: 64 162 | n_layers: 1 163 | bbox: [[-0.9, 0, -0.2], [-0.2, 0.35, 0.2]] 164 | 165 | rgb_resd_loss_coe: 0.01 166 | 167 | train: 168 | batch_size: 1 169 | collator: "" 170 | lr: 5e-4 171 | eps: 1e-15 172 | weight_decay: 0 173 | epoch: 6 174 | scheduler: 175 | type: "exponential" 176 | gamma: 0.1 177 | decay_epochs: 1000 178 | num_workers: 16 179 | 180 | test: 181 | sampler: "FrameSampler" 182 | batch_size: 1 183 | frame_sampler_interval: 4 184 | collator: "" 185 | 186 | val: 187 | frame_sampler_interval: 4 188 | 189 | eval_ep: 10 190 | save_latest_ep: 5 191 | save_ep: 400 192 | ep_iter: 500 193 | vis_ep: 100 194 | 195 | use_lpips: True 196 | use_time_embedder: True 197 | use_reg_distortion: True 198 | 199 | training_stages: 200 | - ratio: 0.3 201 | _start: 0 202 | - ratio: 0.5 203 | sample_focus: head 204 | _start: 2 205 | - ratio: 0.5 206 | sample_focus: "" 207 | reg_dist_weight: 1.0 208 | _start: 4 209 | 210 | smpl: smpl 211 | lbs: smpl_lbs 212 | params: smpl_params 213 | vertices: smpl_vertices 214 | 215 | part_deform: False 216 | patch_size: 64 217 | 218 | log_interval: 100 219 | 220 | train_dataset: 221 | data_root: "data/default" 222 | human: "" 223 | ann_file: "data/default/annots.npy" 224 | split: "train" 225 | 226 | val_dataset: 227 | data_root: "data/default" 228 | human: "" 229 | ann_file: "data/default/annots.npy" 230 | split: "val" 231 | 232 | test_dataset: 233 | data_root: "data/default" 234 | human: "" 235 | ann_file: "data/default/annots.npy" 236 | split: "test" 237 
| 238 | bullet: 239 | dataset_module: "lib.datasets.h36m.tpose_novel_view_dataset" 240 | dataset_kwargs: 241 | data_root: "data/default" 242 | ann_file: "data/default/annots.npy" 243 | human: "" 244 | split: "test" 245 | visualizer_module: "lib.visualizers.if_nerf_demo" 246 | 247 | tmesh: 248 | dataset_module: 'lib.datasets.h36m.tpose_dataset_mesh' 249 | dataset_kwargs: 250 | data_root: 'data/default' 251 | human: '' 252 | ann_file: 'data/default/annots.npy' 253 | split: 'tmesh' 254 | renderer_module: 'lib.networks.renderer.pose_mesh_renderer' 255 | 256 | # data options 257 | training_view: [0] 258 | test_view: [0] 259 | begin_ith_frame: 0 260 | num_train_frame: 200 261 | frame_interval: 1 262 | 263 | smpl_thresh: 0.05 264 | exp_name: default 265 | pair_loss_weight: 10.0 266 | 267 | eval_ratio: 0.5 268 | silent: False 269 | -------------------------------------------------------------------------------- /configs/inb/inb_377.yaml: -------------------------------------------------------------------------------- 1 | task: "inb" 2 | gpus: [0] 3 | 4 | zju_human: "" 5 | 6 | train_dataset_module: "lib.datasets.h36m.tpose_dataset" 7 | val_dataset_module: "lib.datasets.h36m.tpose_dataset" 8 | test_dataset_module: "lib.datasets.h36m.tpose_dataset" 9 | prune_dataset_module: "lib.datasets.h36m.tpose_dataset" 10 | 11 | network_module: "lib.networks.bw_deform.inb_part_network_multiassign" 12 | renderer_module: "lib.networks.renderer.inb_renderer" 13 | 14 | renderer_vis_module: "lib.networks.renderer.inb_renderer" 15 | trainer_module: "lib.train.trainers.inb_trainer" 16 | evaluator_module: "lib.evaluators.if_nerf" 17 | visualizer_module: "lib.visualizers.if_nerf" 18 | 19 | network: 20 | occ: 21 | d_hidden: 64 22 | n_layers: 1 23 | color: 24 | d_hidden: 64 25 | n_layers: 2 26 | 27 | viewdir_embedder: 28 | module: "lib.networks.embedders.freq_embedder" 29 | kwargs: 30 | res: 4 31 | input_dims: 3 32 | 33 | # training options 34 | train_th: 0.1 35 | norm_th: 0.1 36 | 37 | # rendering options 38 | i_embed: 0 39 | xyz_res: 10 40 | view_res: 4 41 | raw_noise_std: 0 42 | 43 | N_samples: 64 44 | N_importance: 128 45 | N_rand: 1024 46 | 47 | perturb: 1 48 | white_bkgd: False 49 | 50 | render_views: 50 51 | 52 | # data options 53 | ratio: 0.5 54 | 55 | # partnet config 56 | tpose_deformer: 57 | module: "lib.networks.deformers.uv_deformer" 58 | embedder: 59 | module: "lib.networks.embedders.part_base_embedder" 60 | kwargs: 61 | n_levels: 8 62 | n_features_per_level: 2 63 | log2_hashmap_size: 14 64 | base_resolution: 4 65 | b: 1.38 66 | sum: False 67 | sum_over_features: True 68 | separate_dense: True 69 | use_batch_bounds: False 70 | include_input: True 71 | 72 | partnet: 73 | body: 74 | module: "lib.networks.bw_deform.part_base_network" 75 | embedder: 76 | module: "lib.networks.embedders.part_base_embedder" 77 | kwargs: 78 | n_levels: 16 79 | n_features_per_level: 16 80 | log2_hashmap_size: 20 81 | base_resolution: 16 82 | b: 1.38 83 | sum: True 84 | sum_over_features: True 85 | separate_dense: True 86 | use_batch_bounds: True 87 | bbox: [[-1, -1.2, -0.34], [0.8, 0.7, 0.5]] 88 | # pretrained: 'exps/inb/inb_377_body1/trained_model/body.pth' 89 | leg: 90 | module: "lib.networks.bw_deform.part_base_network" 91 | embedder: 92 | module: "lib.networks.embedders.part_base_embedder" 93 | kwargs: 94 | n_levels: 16 95 | n_features_per_level: 16 96 | log2_hashmap_size: 20 97 | base_resolution: 2 98 | b: 1.38 99 | sum: True 100 | sum_over_features: True 101 | separate_dense: True 102 | use_batch_bounds: True 103 | 
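# note: the leg and arm parts specify their own single-layer color heads below (the shared network.color head near the top of this file uses two layers; the body and head parts do not override it)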
color_network: 104 | module: "lib.networks.bw_deform.part_base_network" 105 | kwargs: 106 | d_hidden: 64 107 | n_layers: 1 108 | bbox: [[-1, -1.2, -0.34], [0.8, -0.3, 0.5]] 109 | head: 110 | module: "lib.networks.bw_deform.part_base_network" 111 | embedder: 112 | module: "lib.networks.embedders.part_base_embedder" 113 | kwargs: 114 | n_levels: 16 115 | n_features_per_level: 16 116 | log2_hashmap_size: 18 117 | base_resolution: 2 118 | b: 1.38 119 | sum: True 120 | sum_over_features: True 121 | separate_dense: True 122 | use_batch_bounds: True 123 | bbox: [[-0.3, 0.3, -0.3], [0.3, 0.7, 0.3]] 124 | larm: 125 | module: "lib.networks.bw_deform.part_base_network" 126 | embedder: 127 | module: "lib.networks.embedders.part_base_embedder" 128 | kwargs: 129 | n_levels: 16 130 | n_features_per_level: 16 131 | log2_hashmap_size: 15 132 | base_resolution: 2 133 | b: 1.38 134 | sum: True 135 | sum_over_features: True 136 | separate_dense: True 137 | use_batch_bounds: True 138 | color_network: 139 | module: "lib.networks.bw_deform.part_base_network" 140 | kwargs: 141 | d_hidden: 64 142 | n_layers: 1 143 | bbox: [[0.2, 0, -0.2], [0.9, 0.35, 0.2]] 144 | rarm: 145 | module: "lib.networks.bw_deform.part_base_network" 146 | embedder: 147 | module: "lib.networks.embedders.part_base_embedder" 148 | kwargs: 149 | n_levels: 16 150 | n_features_per_level: 16 151 | log2_hashmap_size: 15 152 | base_resolution: 2 153 | b: 1.38 154 | sum: True 155 | sum_over_features: True 156 | separate_dense: True 157 | use_batch_bounds: True 158 | color_network: 159 | module: "lib.networks.bw_deform.part_base_network" 160 | kwargs: 161 | d_hidden: 64 162 | n_layers: 1 163 | bbox: [[-0.9, 0, -0.2], [-0.2, 0.35, 0.2]] 164 | 165 | rgb_resd_loss_coe: 0.01 166 | 167 | train: 168 | batch_size: 1 169 | collator: "" 170 | lr: 5e-4 171 | eps: 1e-15 172 | weight_decay: 0 173 | epoch: 6 174 | scheduler: 175 | type: "exponential" 176 | gamma: 0.1 177 | decay_epochs: 1000 178 | num_workers: 16 179 | 180 | test: 181 | sampler: "FrameSampler" 182 | batch_size: 1 183 | frame_sampler_interval: 10 184 | collator: "" 185 | frame_sampler_interval: 6 186 | 187 | val: 188 | frame_sampler_interval: 20 189 | 190 | eval_ep: 10 191 | save_latest_ep: 5 192 | save_ep: 400 193 | ep_iter: 500 194 | vis_ep: 100 195 | 196 | use_lpips: True 197 | use_time_embedder: True 198 | use_reg_distortion: True 199 | 200 | training_stages: 201 | - ratio: 0.3 202 | _start: 0 203 | - ratio: 0.5 204 | sample_focus: head 205 | _start: 2 206 | - ratio: 0.5 207 | sample_focus: "" 208 | reg_dist_weight: 1.0 209 | _start: 4 210 | 211 | smpl: smpl 212 | lbs: smpl_lbs 213 | params: smpl_params 214 | vertices: smpl_vertices 215 | 216 | part_deform: False 217 | patch_size: 64 218 | 219 | log_interval: 100 220 | 221 | train_dataset: 222 | data_root: "data/zju-mocap/my_377" 223 | human: "my_377" 224 | ann_file: "data/zju-mocap/my_377/annots.npy" 225 | split: "train" 226 | 227 | val_dataset: 228 | data_root: "data/zju-mocap/my_377" 229 | human: "my_377" 230 | ann_file: "data/zju-mocap/my_377/annots.npy" 231 | split: "val" 232 | 233 | test_dataset: 234 | data_root: "data/zju-mocap/my_377" 235 | human: "my_377" 236 | ann_file: "data/zju-mocap/my_377/annots.npy" 237 | split: "test" 238 | 239 | bullet: 240 | dataset_module: "lib.datasets.h36m.tpose_novel_view_dataset" 241 | dataset_kwargs: 242 | data_root: "data/zju-mocap/my_377" 243 | ann_file: "data/zju-mocap/my_377/annots.npy" 244 | human: "my_377" 245 | split: "test" 246 | visualizer_module: "lib.visualizers.if_nerf_demo" 247 | 248 | 
tmesh: 249 | dataset_module: 'lib.datasets.h36m.tpose_dataset_mesh' 250 | dataset_kwargs: 251 | data_root: 'data/zju-mocap/my_377' 252 | human: 'my_377' 253 | ann_file: 'data/zju-mocap/my_377/annots.npy' 254 | split: 'tmesh' 255 | renderer_module: 'lib.networks.renderer.pose_mesh_renderer' 256 | 257 | # data options 258 | training_view: [4] 259 | test_view: [] 260 | begin_ith_frame: 0 261 | num_train_frame: 100 262 | frame_interval: 5 263 | 264 | smpl_thresh: 0.05 265 | exp_name: inb_377 266 | pair_loss_weight: 10.0 267 | 268 | eval_ratio: 0.5 269 | silent: False 270 | -------------------------------------------------------------------------------- /configs/inb/inb_386.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: "configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/zju-mocap/my_386" 5 | human: "my_386" 6 | ann_file: "data/zju-mocap/my_386/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/zju-mocap/my_386" 10 | human: "my_386" 11 | ann_file: "data/zju-mocap/my_386/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/zju-mocap/my_386" 15 | human: "my_386" 16 | ann_file: "data/zju-mocap/my_386/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/zju-mocap/my_386" 21 | ann_file: "data/zju-mocap/my_386/annots.npy" 22 | human: "my_386" 23 | 24 | # data options 25 | training_view: [4] 26 | test_view: [] 27 | begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | 31 | test: 32 | frame_sampler_interval: 6 33 | 34 | smpl_thresh: 0.05 35 | exp_name: inb_386 36 | pair_loss_weight: 1e-6 37 | resd_loss_weight: 1e-2 38 | -------------------------------------------------------------------------------- /configs/inb/inb_387.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: "configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/zju-mocap/my_387" 5 | human: "my_387" 6 | ann_file: "data/zju-mocap/my_387/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/zju-mocap/my_387" 10 | human: "my_387" 11 | ann_file: "data/zju-mocap/my_387/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/zju-mocap/my_387" 15 | human: "my_387" 16 | ann_file: "data/zju-mocap/my_387/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/zju-mocap/my_387" 21 | ann_file: "data/zju-mocap/my_387/annots.npy" 22 | human: "my_387" 23 | 24 | # data options 25 | training_view: [4] 26 | test_view: [] 27 | begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | 31 | test: 32 | frame_sampler_interval: 6 33 | 34 | smpl_thresh: 0.05 35 | exp_name: inb_387 36 | pair_loss_weight: 10.0 -------------------------------------------------------------------------------- /configs/inb/inb_390.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: "configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/zju-mocap/my_390" 5 | human: "my_390" 6 | ann_file: "data/zju-mocap/my_390/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/zju-mocap/my_390" 10 | human: "my_390" 11 | ann_file: "data/zju-mocap/my_390/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/zju-mocap/my_390" 15 | human: "my_390" 16 | ann_file: "data/zju-mocap/my_390/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/zju-mocap/my_390" 21 | ann_file: "data/zju-mocap/my_390/annots.npy" 22 | human: "my_390" 23 | 24 | # data options 25 | training_view: [4] 26 | test_view: [] 27 | 
begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | 31 | test: 32 | frame_sampler_interval: 6 33 | 34 | smpl_thresh: 0.05 35 | exp_name: inb_390 36 | pair_loss_weight: 1e-4 -------------------------------------------------------------------------------- /configs/inb/inb_392.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: "configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/zju-mocap/my_392" 5 | human: "my_392" 6 | ann_file: "data/zju-mocap/my_392/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/zju-mocap/my_392" 10 | human: "my_392" 11 | ann_file: "data/zju-mocap/my_392/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/zju-mocap/my_392" 15 | human: "my_392" 16 | ann_file: "data/zju-mocap/my_392/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/zju-mocap/my_392" 21 | ann_file: "data/zju-mocap/my_392/annots.npy" 22 | human: "my_392" 23 | 24 | # data options 25 | training_view: [4] 26 | test_view: [] 27 | begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | 31 | test: 32 | frame_sampler_interval: 6 33 | 34 | smpl_thresh: 0.05 35 | exp_name: inb_392 36 | pair_loss_weight: 10.0 -------------------------------------------------------------------------------- /configs/inb/inb_393.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: "configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/zju-mocap/my_393" 5 | human: "my_393" 6 | ann_file: "data/zju-mocap/my_393/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/zju-mocap/my_393" 10 | human: "my_393" 11 | ann_file: "data/zju-mocap/my_393/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/zju-mocap/my_393" 15 | human: "my_393" 16 | ann_file: "data/zju-mocap/my_393/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/zju-mocap/my_393" 21 | ann_file: "data/zju-mocap/my_393/annots.npy" 22 | human: "my_393" 23 | 24 | # data options 25 | training_view: [4] 26 | test_view: [] 27 | begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | 31 | test: 32 | frame_sampler_interval: 6 33 | 34 | smpl_thresh: 0.05 35 | exp_name: inb_393 36 | pair_loss_weight: 10.0 -------------------------------------------------------------------------------- /configs/inb/inb_394.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: "configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/zju-mocap/my_394" 5 | human: "my_394" 6 | ann_file: "data/zju-mocap/my_394/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/zju-mocap/my_394" 10 | human: "my_394" 11 | ann_file: "data/zju-mocap/my_394/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/zju-mocap/my_394" 15 | human: "my_394" 16 | ann_file: "data/zju-mocap/my_394/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/zju-mocap/my_394" 21 | ann_file: "data/zju-mocap/my_394/annots.npy" 22 | human: "my_394" 23 | 24 | # data options 25 | training_view: [4] 26 | test_view: [] 27 | begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | 31 | test: 32 | frame_sampler_interval: 6 33 | 34 | smpl_thresh: 0.05 35 | exp_name: inb_394 36 | pair_loss_weight: 10.0 -------------------------------------------------------------------------------- /configs/inb/inb_lan.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: 
"configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/monocap/lan_images620_1300" 5 | human: "lan_images620_1300" 6 | ann_file: "data/monocap/lan_images620_1300/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/monocap/lan_images620_1300" 10 | human: "lan_images620_1300" 11 | ann_file: "data/monocap/lan_images620_1300/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/monocap/lan_images620_1300" 15 | human: "lan_images620_1300" 16 | ann_file: "data/monocap/lan_images620_1300/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/monocap/lan_images620_1300" 21 | ann_file: "data/monocap/lan_images620_1300/annots.npy" 22 | human: "lan_images620_1300" 23 | 24 | # data options 25 | training_view: [0] 26 | test_view: [] 27 | begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | erode_edge: False 31 | ratio: 1.0 32 | 33 | training_stages: 34 | - ratio: 1.0 35 | _start: 0 36 | 37 | train: 38 | lr: 8e-4 39 | 40 | test: 41 | frame_sampler_interval: 6 42 | 43 | smpl: "smpl" 44 | lbs: "lbs" 45 | params: "params" 46 | vertices: "vertices" 47 | 48 | smpl_thresh: 0.1 49 | exp_name: inb_lan 50 | pair_loss_weight: 1e-4 51 | 52 | train: 53 | lr: 1e-3 54 | 55 | eval_ratio: 1.0 56 | 57 | training_stages: 58 | - ratio: 0.3 59 | _start: 0 60 | - ratio: 0.5 61 | sample_focus: head 62 | _start: 2 63 | - ratio: 0.5 64 | sample_focus: "" 65 | reg_dist_weight: 1.0 66 | _start: 4 67 | -------------------------------------------------------------------------------- /configs/inb/inb_marc.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: "configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/monocap/marc_images35000_36200" 5 | human: "marc_images35000_36200" 6 | ann_file: "data/monocap/marc_images35000_36200/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/monocap/marc_images35000_36200" 10 | human: "marc_images35000_36200" 11 | ann_file: "data/monocap/marc_images35000_36200/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/monocap/marc_images35000_36200" 15 | human: "marc_images35000_36200" 16 | ann_file: "data/monocap/marc_images35000_36200/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/monocap/marc_images35000_36200" 21 | ann_file: "data/monocap/marc_images35000_36200/annots.npy" 22 | human: "marc_images35000_36200" 23 | 24 | # data options 25 | training_view: [0] 26 | test_view: [] 27 | begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | erode_edge: False 31 | ratio: 1.0 32 | 33 | # resd_loss_weight: 1.0 34 | 35 | # training_stages: 36 | # - ratio: 0.6 37 | # _start: 0 38 | # - ratio: 1.0 39 | # _start: 2 40 | # sample_focus: head 41 | # - ratio: 1.0 42 | # _start: 3 43 | # sample_focus: "" 44 | 45 | train: 46 | lr: 1.2e-3 47 | 48 | # use_pair_reg: False 49 | test: 50 | frame_sampler_interval: 6 51 | 52 | smpl: "smpl" 53 | lbs: "lbs" 54 | params: "params" 55 | vertices: "vertices" 56 | 57 | smpl_thresh: 0.1 58 | exp_name: inb_marc 59 | pair_loss_weight: 1e-4 60 | 61 | train: 62 | lr: 1e-3 63 | 64 | eval_ratio: 1.0 65 | 66 | training_stages: 67 | - ratio: 0.3 68 | _start: 0 69 | - ratio: 0.5 70 | sample_focus: head 71 | _start: 2 72 | - ratio: 0.5 73 | sample_focus: "" 74 | reg_dist_weight: 1.0 75 | _start: 4 76 | -------------------------------------------------------------------------------- /configs/inb/inb_olek.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: 
"configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/monocap/olek_images0812" 5 | human: "olek_images0812" 6 | ann_file: "data/monocap/olek_images0812/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/monocap/olek_images0812" 10 | human: "olek_images0812" 11 | ann_file: "data/monocap/olek_images0812/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/monocap/olek_images0812" 15 | human: "olek_images0812" 16 | ann_file: "data/monocap/olek_images0812/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/monocap/olek_images0812" 21 | ann_file: "data/monocap/olek_images0812/annots.npy" 22 | human: "olek_images0812" 23 | 24 | # data options 25 | training_view: [44] 26 | test_view: [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 49] 27 | begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | erode_edge: False 31 | ratio: 1.0 32 | 33 | # training_stages: 34 | # - ratio: 1.0 35 | # _start: 0 36 | 37 | # train: 38 | # lr: 8e-4 39 | 40 | test: 41 | frame_sampler_interval: 6 42 | 43 | smpl: "smpl" 44 | lbs: "lbs" 45 | params: "params" 46 | vertices: "vertices" 47 | 48 | smpl_thresh: 0.05 49 | exp_name: inb_olek 50 | pair_loss_weight: 10.0 51 | 52 | train: 53 | lr: 1e-3 54 | 55 | eval_ratio: 1.0 56 | 57 | training_stages: 58 | - ratio: 0.3 59 | _start: 0 60 | - ratio: 0.5 61 | sample_focus: head 62 | _start: 2 63 | - ratio: 0.5 64 | sample_focus: "" 65 | reg_dist_weight: 1.0 66 | _start: 4 67 | -------------------------------------------------------------------------------- /configs/inb/inb_vlad.yaml: -------------------------------------------------------------------------------- 1 | parent_cfg: "configs/inb/inb_377.yaml" 2 | 3 | train_dataset: 4 | data_root: "data/monocap/vlad_images1011" 5 | human: "vlad_images1011" 6 | ann_file: "data/monocap/vlad_images1011/annots.npy" 7 | 8 | val_dataset: 9 | data_root: "data/monocap/vlad_images1011" 10 | human: "vlad_images1011" 11 | ann_file: "data/monocap/vlad_images1011/annots.npy" 12 | 13 | test_dataset: 14 | data_root: "data/monocap/vlad_images1011" 15 | human: "vlad_images1011" 16 | ann_file: "data/monocap/vlad_images1011/annots.npy" 17 | 18 | bullet: 19 | dataset_kwargs: 20 | data_root: "data/monocap/vlad_images1011" 21 | ann_file: "data/monocap/vlad_images1011/annots.npy" 22 | human: "vlad_images1011" 23 | 24 | # data options 25 | training_view: [66] 26 | test_view: [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] 27 | begin_ith_frame: 0 28 | num_train_frame: 100 29 | frame_interval: 5 30 | erode_edge: False 31 | ratio: 1.0 32 | 33 | training_stages: 34 | - ratio: 1.0 35 | _start: 0 36 | 37 | train: 38 | lr: 8e-4 39 | 40 | test: 41 | frame_sampler_interval: 6 42 | 43 | smpl: "smpl" 44 | lbs: "lbs" 45 | params: "params" 46 | vertices: "vertices" 47 | 48 | smpl_thresh: 0.1 49 | exp_name: inb_vlad 50 | pair_loss_weight: 1e-4 51 | 52 | train: 53 | lr: 1e-3 54 | 55 | eval_ratio: 1.0 56 | 57 | training_stages: 58 | - ratio: 0.3 59 | _start: 0 60 | - ratio: 0.5 61 | sample_focus: head 62 | _start: 2 63 | - ratio: 0.5 64 | sample_focus: "" 65 | reg_dist_weight: 1.0 66 | _start: 4 67 | -------------------------------------------------------------------------------- /configs/monocular.yml: -------------------------------------------------------------------------------- 1 | task: "monocular" 2 | exp_name: "dance" 3 | gpus: [1] 4 | 5 | default_cfg_path: "configs/default.yml" 6 | models_path: 'data/models' 7 | smpl_path: 'data/smpl-meta' 8 | data_root: 'data/dance' 9 | 10 | tmp_path: '/mnt/data/' 11 | 12 | 
begin_ith_frame: 0 13 | num_train_frame: 100 14 | frame_interval: 3 15 | 16 | 17 | # 3rdparty 18 | easymocap_path: '~/opt/EasyMocap' 19 | schp_path: '~/opt/EasyMocap/3rdparty/Self-Correction-Human-Parsing' -------------------------------------------------------------------------------- /docs/install.md: -------------------------------------------------------------------------------- 1 | ### Set up the Python environment 2 | 3 | Initialize the Python environment by running: 4 | 5 | ```shell 6 | conda create -n instant-nvr python=3.9 7 | conda activate instant-nvr 8 | ``` 9 | 10 | Then, install pytorch3d=0.7.2 according to the [instructions here](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md). 11 | 12 | Finally, install other packages by running: 13 | 14 | ```shell 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ### Set up datasets 19 | 20 | For both datasets, we provide refined camera parameters. See below for further details. 21 | 22 | #### ZJU-MoCap dataset 23 | 24 | Since the dataset is licensed, we require users to agree to some policies before obtaining the dataset. Anyone who wants access to the dataset can fill out this [form](https://forms.gle/rBSXkXpuSsHep4R26) to obtain the instructions to download the data. Alternatively, you can fill in this [agreement](https://pengsida.net/project_page_assets/files/Refined_ZJU-MoCap_Agreement.pdf) and email it to [Chen Geng](mailto:chen.geng@cs.stanford.edu) with cc to [Sida Peng](mailto:pengsida@zju.edu.cn) and [Xiaowei Zhou](mailto:xwzhou@zju.edu.cn) to obtain access. 25 | 26 | **Please note, even if you have previously downloaded the ZJU-MoCap dataset from our [previous work](https://github.com/zju3dv/neuralbody), it is essential to re-download it. We have refined the dataset to include more accurate camera parameters and additional auxiliary files that are crucial for running our code.** 27 | 28 | > If you've sent an email and have not received a response within three days, there's a possibility that your email may have been overlooked. We kindly request that you resend the email as a reminder. 29 | 30 | After acquiring the link, set up the dataset by: 31 | 32 | ```shell 33 | ROOT=/path/to/instant-nvr 34 | mkdir -p $ROOT/data 35 | cd $ROOT/data 36 | ln -s /path/to/my-zjumocap zju-mocap 37 | ``` 38 | 39 | #### MonoCap dataset 40 | 41 | Following [animatable_nerf](https://github.com/zju3dv/animatable_nerf/blob/master/INSTALL.md#monocap-dataset), the dataset is composed of [DeepCap](https://people.mpi-inf.mpg.de/~mhaberma/projects/2020-cvpr-deepcap/) and [DynaCap](https://people.mpi-inf.mpg.de/~mhaberma/projects/2021-ddc/), whose licenses forbid further distribution. 42 | 43 | Please download the raw data [here](https://gvv-assets.mpi-inf.mpg.de/) and email [Chen Geng](mailto:chen.geng@cs.stanford.edu) with cc to [Sida Peng](mailto:pengsida@zju.edu.cn) for instructions on how to process this dataset. 44 | 45 | After successfully obtaining the dataset, set it up by: 46 | 47 | ```shell 48 | ROOT=/path/to/instant-nvr 49 | mkdir -p $ROOT/data 50 | cd $ROOT/data 51 | ln -s /path/to/monocap monocap 52 | ``` 53 | 54 | #### Custom Dataset 55 | 56 | We have recently uploaded our instructions for processing custom datasets [here](https://github.com/zju3dv/instant-nvr/blob/master/docs/preprocess.md). They are still at an early stage and have not been fully tested yet. We need your feedback! Please try them and let us know if you have any questions. 
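#### Checking the data layout

As a quick sanity check after linking the datasets, you can verify that the paths referenced by the provided configs resolve. This is only a rough sketch, not an exhaustive listing of the per-subject contents: the subject folders below (`my_377`, `lan_images620_1300`) are simply the ones named in `configs/inb/inb_377.yaml` and `configs/inb/inb_lan.yaml`, so substitute whichever sequences you actually downloaded.

```shell
# Paths taken from the provided configs; adjust the subject folders as needed.
ls data/zju-mocap/my_377/annots.npy            # annotation file referenced by the ZJU-MoCap configs
ls data/monocap/lan_images620_1300/annots.npy  # annotation file referenced by the MonoCap configs
```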
57 | -------------------------------------------------------------------------------- /docs/media/inb.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zju3dv/instant-nvr/02decc423cc882deffee053cdbdee8b70c8285ec/docs/media/inb.gif -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch3d 2 | channels: 3 | - pytorch3d 4 | - iopath 5 | - pytorch 6 | - nvidia 7 | - conda-forge 8 | - defaults 9 | dependencies: 10 | - _libgcc_mutex=0.1=main 11 | - _openmp_mutex=5.1=1_gnu 12 | - blas=1.0=mkl 13 | - brotlipy=0.7.0=py39h27cfd23_1003 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2023.01.10=h06a4308_0 16 | - certifi=2022.12.7=py39h06a4308_0 17 | - cffi=1.15.1=py39h5eee18b_3 18 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 19 | - colorama=0.4.6=pyhd8ed1ab_0 20 | - cryptography=38.0.4=py39h9ce1e76_0 21 | - cuda=11.6.1=0 22 | - cuda-cccl=11.6.55=hf6102b2_0 23 | - cuda-command-line-tools=11.6.2=0 24 | - cuda-compiler=11.6.2=0 25 | - cuda-cudart=11.6.55=he381448_0 26 | - cuda-cudart-dev=11.6.55=h42ad0f4_0 27 | - cuda-cuobjdump=11.6.124=h2eeebcb_0 28 | - cuda-cupti=11.6.124=h86345e5_0 29 | - cuda-cuxxfilt=11.6.124=hecbf4f6_0 30 | - cuda-driver-dev=11.6.55=0 31 | - cuda-gdb=12.0.140=0 32 | - cuda-libraries=11.6.1=0 33 | - cuda-libraries-dev=11.6.1=0 34 | - cuda-memcheck=11.8.86=0 35 | - cuda-nsight=12.0.140=0 36 | - cuda-nsight-compute=12.0.1=0 37 | - cuda-nvcc=11.6.124=hbba6d2d_0 38 | - cuda-nvdisasm=12.0.140=0 39 | - cuda-nvml-dev=11.6.55=haa9ef22_0 40 | - cuda-nvprof=12.0.146=0 41 | - cuda-nvprune=11.6.124=he22ec0a_0 42 | - cuda-nvrtc=11.6.124=h020bade_0 43 | - cuda-nvrtc-dev=11.6.124=h249d397_0 44 | - cuda-nvtx=11.6.124=h0630a44_0 45 | - cuda-nvvp=12.0.146=0 46 | - cuda-runtime=11.6.1=0 47 | - cuda-samples=11.6.101=h8efea70_0 48 | - cuda-sanitizer-api=12.0.140=0 49 | - cuda-toolkit=11.6.1=0 50 | - cuda-tools=11.6.1=0 51 | - cuda-visual-tools=11.6.1=0 52 | - ffmpeg=4.3=hf484d3e_0 53 | - flit-core=3.6.0=pyhd3eb1b0_0 54 | - freetype=2.12.1=h4a9f257_0 55 | - fvcore=0.1.5.post20221221=pyhd8ed1ab_0 56 | - gds-tools=1.5.1.14=0 57 | - giflib=5.2.1=h5eee18b_1 58 | - gmp=6.2.1=h295c915_3 59 | - gnutls=3.6.15=he1e5248_0 60 | - idna=3.4=py39h06a4308_0 61 | - intel-openmp=2021.4.0=h06a4308_3561 62 | - iopath=0.1.9=py39 63 | - jpeg=9e=h7f8727e_0 64 | - lame=3.100=h7b6447c_0 65 | - lcms2=2.12=h3be6417_0 66 | - ld_impl_linux-64=2.38=h1181459_1 67 | - lerc=3.0=h295c915_0 68 | - libcublas=11.9.2.110=h5e84587_0 69 | - libcublas-dev=11.9.2.110=h5c901ab_0 70 | - libcufft=10.7.1.112=hf425ae0_0 71 | - libcufft-dev=10.7.1.112=ha5ce4c0_0 72 | - libcufile=1.5.1.14=0 73 | - libcufile-dev=1.5.1.14=0 74 | - libcurand=10.3.1.124=0 75 | - libcurand-dev=10.3.1.124=0 76 | - libcusolver=11.3.4.124=h33c3c4e_0 77 | - libcusparse=11.7.2.124=h7538f96_0 78 | - libcusparse-dev=11.7.2.124=hbbe9722_0 79 | - libdeflate=1.8=h7f8727e_5 80 | - libffi=3.4.2=h6a678d5_6 81 | - libgcc-ng=11.2.0=h1234567_1 82 | - libgomp=11.2.0=h1234567_1 83 | - libiconv=1.16=h7f8727e_2 84 | - libidn2=2.3.2=h7f8727e_0 85 | - libnpp=11.6.3.124=hd2722f0_0 86 | - libnpp-dev=11.6.3.124=h3c42840_0 87 | - libnvjpeg=11.6.2.124=hd473ad6_0 88 | - libnvjpeg-dev=11.6.2.124=hb5906b9_0 89 | - libpng=1.6.37=hbc83047_0 90 | - libstdcxx-ng=11.2.0=h1234567_1 91 | - libtasn1=4.16.0=h27cfd23_0 92 | - libtiff=4.5.0=h6a678d5_1 93 | - libunistring=0.9.10=h27cfd23_0 94 | - 
libwebp=1.2.4=h11a3e52_0 95 | - libwebp-base=1.2.4=h5eee18b_0 96 | - lz4-c=1.9.4=h6a678d5_0 97 | - mkl=2021.4.0=h06a4308_640 98 | - mkl-service=2.4.0=py39h7f8727e_0 99 | - mkl_fft=1.3.1=py39hd3c417c_0 100 | - mkl_random=1.2.2=py39h51133e4_0 101 | - ncurses=6.4=h6a678d5_0 102 | - nettle=3.7.3=hbbd107a_1 103 | - nsight-compute=2022.4.1.6=0 104 | - numpy=1.23.5=py39h14f4228_0 105 | - numpy-base=1.23.5=py39h31eccc5_0 106 | - openh264=2.1.1=h4ff587b_0 107 | - openssl=1.1.1t=h7f8727e_0 108 | - pillow=9.3.0=py39h6a678d5_2 109 | - pip=22.3.1=py39h06a4308_0 110 | - portalocker=2.7.0=py39hf3d152e_0 111 | - pycparser=2.21=pyhd3eb1b0_0 112 | - pyopenssl=22.0.0=pyhd3eb1b0_0 113 | - pysocks=1.7.1=py39h06a4308_0 114 | - python=3.9.16=h7a1cb2a_0 115 | - python_abi=3.9=2_cp39 116 | - pytorch=1.13.0=py3.9_cuda11.6_cudnn8.3.2_0 117 | - pytorch-cuda=11.6=h867d48c_1 118 | - pytorch-mutex=1.0=cuda 119 | - pytorch3d=0.7.2=py39_cu116_pyt1130 120 | - pyyaml=6.0=py39hb9d737c_4 121 | - readline=8.2=h5eee18b_0 122 | - requests=2.28.1=py39h06a4308_0 123 | - setuptools=65.6.3=py39h06a4308_0 124 | - six=1.16.0=pyhd3eb1b0_1 125 | - sqlite=3.40.1=h5082296_0 126 | - tabulate=0.9.0=pyhd8ed1ab_1 127 | - termcolor=2.2.0=pyhd8ed1ab_0 128 | - tk=8.6.12=h1ccaba5_0 129 | - torchvision=0.14.0=py39_cu116 130 | - tqdm=4.64.1=pyhd8ed1ab_0 131 | - typing_extensions=4.4.0=py39h06a4308_0 132 | - urllib3=1.26.14=py39h06a4308_0 133 | - wheel=0.38.4=py39h06a4308_0 134 | - xz=5.2.10=h5eee18b_1 135 | - yacs=0.1.8=pyhd8ed1ab_0 136 | - yaml=0.2.5=h7f98852_2 137 | - zlib=1.2.13=h5eee18b_0 138 | - zstd=1.5.2=ha4553b6_0 139 | - pip: 140 | - addict==2.4.0 141 | - asttokens==2.2.1 142 | - attrs==23.1.0 143 | - backcall==0.2.0 144 | - beartype==0.14.0 145 | - blenderproc==2.5.0 146 | - blinker==1.6.2 147 | - click==8.1.3 148 | - colored-traceback==0.3.0 149 | - comm==0.1.3 150 | - configargparse==1.5.3 151 | - contourpy==1.0.7 152 | - cycler==0.11.0 153 | - dash==2.9.3 154 | - dash-core-components==2.0.0 155 | - dash-html-components==2.0.0 156 | - dash-table==5.0.0 157 | - debugpy==1.6.7 158 | - decorator==5.1.1 159 | - executing==1.2.0 160 | - fastjsonschema==2.17.1 161 | - flask==2.3.2 162 | - fonttools==4.38.0 163 | - gitdb==4.0.10 164 | - gitpython==3.1.31 165 | - h5py==3.8.0 166 | - imageio==2.25.1 167 | - importlib-metadata==6.6.0 168 | - importlib-resources==5.12.0 169 | - ipdb==0.13.13 170 | - ipykernel==6.23.1 171 | - ipython==8.13.2 172 | - ipywidgets==8.0.6 173 | - itsdangerous==2.1.2 174 | - jedi==0.18.2 175 | - jinja2==3.1.2 176 | - joblib==1.2.0 177 | - jsonschema==4.17.3 178 | - jupyter-client==8.2.0 179 | - jupyter-core==5.3.0 180 | - jupyterlab-widgets==3.0.7 181 | - kiwisolver==1.4.4 182 | - lpips==0.1.4 183 | - markupsafe==2.1.2 184 | - matplotlib==3.7.0 185 | - matplotlib-inline==0.1.6 186 | - mpmath==1.2.1 187 | - nbformat==5.7.0 188 | - nest-asyncio==1.5.6 189 | - networkx==3.0 190 | - nptyping==2.5.0 191 | - open3d==0.17.0 192 | - opencv-python==4.7.0.68 193 | - packaging==23.0 194 | - pandas==2.0.1 195 | - parso==0.8.3 196 | - pexpect==4.8.0 197 | - pickleshare==0.7.5 198 | - platformdirs==3.5.1 199 | - plotly==5.14.1 200 | - plyfile==0.7.4 201 | - progress==1.6 202 | - progressbar==2.5 203 | - prompt-toolkit==3.0.38 204 | - protobuf==3.20.3 205 | - psutil==5.9.5 206 | - ptyprocess==0.7.0 207 | - pure-eval==0.2.2 208 | - pygments==2.14.0 209 | - pyparsing==3.0.9 210 | - pyquaternion==0.9.9 211 | - pyrsistent==0.19.3 212 | - python-dateutil==2.8.2 213 | - pytz==2023.3 214 | - pywavelets==1.4.1 215 | - pyzmq==25.0.2 216 | - 
scikit-image==0.19.3 217 | - scikit-learn==1.2.2 218 | - scipy==1.10.1 219 | - smmap==5.0.0 220 | - stack-data==0.6.2 221 | - sympy==1.11.1 222 | - tenacity==8.2.2 223 | - tensorboardx==2.6 224 | - threadpoolctl==3.1.0 225 | - tifffile==2023.2.3 226 | - tomli==2.0.1 227 | - tornado==6.3.2 228 | - traitlets==5.9.0 229 | - trimesh==3.20.0 230 | - tzdata==2023.3 231 | - wcwidth==0.2.6 232 | - werkzeug==2.3.4 233 | - widgetsnbextension==4.0.7 234 | - zipp==3.14.0 235 | prefix: /home/gengchen/miniconda3/envs/pytorch3d 236 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zju3dv/instant-nvr/02decc423cc882deffee053cdbdee8b70c8285ec/lib/__init__.py -------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import cfg, args 2 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from . import pointnet2_utils 6 | from . import pytorch_utils as pt_utils 7 | from typing import List 8 | 9 | 10 | class _PointnetSAModuleBase(nn.Module): 11 | 12 | def __init__(self): 13 | super().__init__() 14 | self.npoint = None 15 | self.groupers = None 16 | self.mlps = None 17 | self.pool_method = 'max_pool' 18 | 19 | def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): 20 | """ 21 | :param xyz: (B, N, 3) tensor of the xyz coordinates of the features 22 | :param features: (B, N, C) tensor of the descriptors of the the features 23 | :param new_xyz: 24 | :return: 25 | new_xyz: (B, npoint, 3) tensor of the new features' xyz 26 | new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors 27 | """ 28 | new_features_list = [] 29 | 30 | xyz_flipped = xyz.transpose(1, 2).contiguous() 31 | if new_xyz is None: 32 | new_xyz = pointnet2_utils.gather_operation( 33 | xyz_flipped, 34 | pointnet2_utils.furthest_point_sample(xyz, self.npoint) 35 | ).transpose(1, 2).contiguous() if self.npoint is not None else None 36 | 37 | for i in range(len(self.groupers)): 38 | new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) 39 | 40 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 41 | if self.pool_method == 'max_pool': 42 | new_features = F.max_pool2d( 43 | new_features, kernel_size=[1, new_features.size(3)] 44 | ) # (B, mlp[-1], npoint, 1) 45 | elif self.pool_method == 'avg_pool': 46 | new_features = F.avg_pool2d( 47 | new_features, kernel_size=[1, new_features.size(3)] 48 | ) # (B, mlp[-1], npoint, 1) 49 | else: 50 | raise NotImplementedError 51 | 52 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 53 | new_features_list.append(new_features) 54 | 55 | return new_xyz, torch.cat(new_features_list, dim=1) 56 | 57 | 58 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 59 | """Pointnet set abstraction layer with multiscale grouping""" 60 | 61 | def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True, 62 | use_xyz: bool = True, pool_method='max_pool', instance_norm=False): 63 | """ 64 | 
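Builds one ball-query grouper (QueryAndGroup, or GroupAll when npoint is None) and one SharedMLP per radius; the shared base forward pools each scale over its samples and concatenates the per-scale features along the channel dimension.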
:param npoint: int 65 | :param radii: list of float, list of radii to group with 66 | :param nsamples: list of int, number of samples in each ball query 67 | :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale 68 | :param bn: whether to use batchnorm 69 | :param use_xyz: 70 | :param pool_method: max_pool / avg_pool 71 | :param instance_norm: whether to use instance_norm 72 | """ 73 | super().__init__() 74 | 75 | assert len(radii) == len(nsamples) == len(mlps) 76 | 77 | self.npoint = npoint 78 | self.groupers = nn.ModuleList() 79 | self.mlps = nn.ModuleList() 80 | for i in range(len(radii)): 81 | radius = radii[i] 82 | nsample = nsamples[i] 83 | self.groupers.append( 84 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 85 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz) 86 | ) 87 | mlp_spec = mlps[i] 88 | if use_xyz: 89 | mlp_spec[0] += 3 90 | 91 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm)) 92 | self.pool_method = pool_method 93 | 94 | 95 | class PointnetSAModule(PointnetSAModuleMSG): 96 | """Pointnet set abstraction layer""" 97 | 98 | def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None, 99 | bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False): 100 | """ 101 | :param mlp: list of int, spec of the pointnet before the global max_pool 102 | :param npoint: int, number of features 103 | :param radius: float, radius of ball 104 | :param nsample: int, number of samples in the ball query 105 | :param bn: whether to use batchnorm 106 | :param use_xyz: 107 | :param pool_method: max_pool / avg_pool 108 | :param instance_norm: whether to use instance_norm 109 | """ 110 | super().__init__( 111 | mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz, 112 | pool_method=pool_method, instance_norm=instance_norm 113 | ) 114 | 115 | 116 | class PointnetFPModule(nn.Module): 117 | r"""Propigates the features of one set to another""" 118 | 119 | def __init__(self, *, mlp: List[int], bn: bool = True): 120 | """ 121 | :param mlp: list of int 122 | :param bn: whether to use batchnorm 123 | """ 124 | super().__init__() 125 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 126 | 127 | def forward( 128 | self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor 129 | ) -> torch.Tensor: 130 | """ 131 | :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features 132 | :param known: (B, m, 3) tensor of the xyz positions of the known features 133 | :param unknow_feats: (B, C1, n) tensor of the features to be propigated to 134 | :param known_feats: (B, C2, m) tensor of features to be propigated 135 | :return: 136 | new_features: (B, mlp[-1], n) tensor of the features of the unknown features 137 | """ 138 | if known is not None: 139 | dist, idx = pointnet2_utils.three_nn(unknown, known) 140 | dist_recip = 1.0 / (dist + 1e-8) 141 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 142 | weight = dist_recip / norm 143 | 144 | interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight) 145 | else: 146 | interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1)) 147 | 148 | if unknow_feats is not None: 149 | new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n) 150 | else: 151 | new_features = interpolated_feats 152 | 153 | new_features = new_features.unsqueeze(-1) 154 | 
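# SharedMLP is a stack of Conv2d blocks, so the features were unsqueezed to (B, C, n, 1) above; they are squeezed back to (B, mlp[-1], n) right after the MLP below.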
new_features = self.mlp(new_features) 155 | 156 | return new_features.squeeze(-1) 157 | 158 | 159 | if __name__ == "__main__": 160 | pass 161 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import List, Tuple 3 | 4 | 5 | class SharedMLP(nn.Sequential): 6 | 7 | def __init__( 8 | self, 9 | args: List[int], 10 | *, 11 | bn: bool = False, 12 | activation=nn.ReLU(inplace=True), 13 | preact: bool = False, 14 | first: bool = False, 15 | name: str = "", 16 | instance_norm: bool = False, 17 | ): 18 | super().__init__() 19 | 20 | for i in range(len(args) - 1): 21 | self.add_module( 22 | name + 'layer{}'.format(i), 23 | Conv2d( 24 | args[i], 25 | args[i + 1], 26 | bn=(not first or not preact or (i != 0)) and bn, 27 | activation=activation 28 | if (not first or not preact or (i != 0)) else None, 29 | preact=preact, 30 | instance_norm=instance_norm 31 | ) 32 | ) 33 | 34 | 35 | class _ConvBase(nn.Sequential): 36 | 37 | def __init__( 38 | self, 39 | in_size, 40 | out_size, 41 | kernel_size, 42 | stride, 43 | padding, 44 | activation, 45 | bn, 46 | init, 47 | conv=None, 48 | batch_norm=None, 49 | bias=True, 50 | preact=False, 51 | name="", 52 | instance_norm=False, 53 | instance_norm_func=None 54 | ): 55 | super().__init__() 56 | 57 | bias = bias and (not bn) 58 | conv_unit = conv( 59 | in_size, 60 | out_size, 61 | kernel_size=kernel_size, 62 | stride=stride, 63 | padding=padding, 64 | bias=bias 65 | ) 66 | init(conv_unit.weight) 67 | if bias: 68 | nn.init.constant_(conv_unit.bias, 0) 69 | 70 | if bn: 71 | if not preact: 72 | bn_unit = batch_norm(out_size) 73 | else: 74 | bn_unit = batch_norm(in_size) 75 | if instance_norm: 76 | if not preact: 77 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) 78 | else: 79 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) 80 | 81 | if preact: 82 | if bn: 83 | self.add_module(name + 'bn', bn_unit) 84 | 85 | if activation is not None: 86 | self.add_module(name + 'activation', activation) 87 | 88 | if not bn and instance_norm: 89 | self.add_module(name + 'in', in_unit) 90 | 91 | self.add_module(name + 'conv', conv_unit) 92 | 93 | if not preact: 94 | if bn: 95 | self.add_module(name + 'bn', bn_unit) 96 | 97 | if activation is not None: 98 | self.add_module(name + 'activation', activation) 99 | 100 | if not bn and instance_norm: 101 | self.add_module(name + 'in', in_unit) 102 | 103 | 104 | class _BNBase(nn.Sequential): 105 | 106 | def __init__(self, in_size, batch_norm=None, name=""): 107 | super().__init__() 108 | self.add_module(name + "bn", batch_norm(in_size)) 109 | 110 | nn.init.constant_(self[0].weight, 1.0) 111 | nn.init.constant_(self[0].bias, 0) 112 | 113 | 114 | class BatchNorm1d(_BNBase): 115 | 116 | def __init__(self, in_size: int, *, name: str = ""): 117 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 118 | 119 | 120 | class BatchNorm2d(_BNBase): 121 | 122 | def __init__(self, in_size: int, name: str = ""): 123 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 124 | 125 | 126 | class Conv1d(_ConvBase): 127 | 128 | def __init__( 129 | self, 130 | in_size: int, 131 | out_size: int, 132 | *, 133 | kernel_size: int = 1, 134 | stride: int = 1, 135 | padding: int = 0, 136 | activation=nn.ReLU(inplace=True), 137 | bn: bool = False, 138 | init=nn.init.kaiming_normal_, 139 | bias: 
bool = True, 140 | preact: bool = False, 141 | name: str = "", 142 | instance_norm=False 143 | ): 144 | super().__init__( 145 | in_size, 146 | out_size, 147 | kernel_size, 148 | stride, 149 | padding, 150 | activation, 151 | bn, 152 | init, 153 | conv=nn.Conv1d, 154 | batch_norm=BatchNorm1d, 155 | bias=bias, 156 | preact=preact, 157 | name=name, 158 | instance_norm=instance_norm, 159 | instance_norm_func=nn.InstanceNorm1d 160 | ) 161 | 162 | 163 | class Conv2d(_ConvBase): 164 | 165 | def __init__( 166 | self, 167 | in_size: int, 168 | out_size: int, 169 | *, 170 | kernel_size: Tuple[int, int] = (1, 1), 171 | stride: Tuple[int, int] = (1, 1), 172 | padding: Tuple[int, int] = (0, 0), 173 | activation=nn.ReLU(inplace=True), 174 | bn: bool = False, 175 | init=nn.init.kaiming_normal_, 176 | bias: bool = True, 177 | preact: bool = False, 178 | name: str = "", 179 | instance_norm=False 180 | ): 181 | super().__init__( 182 | in_size, 183 | out_size, 184 | kernel_size, 185 | stride, 186 | padding, 187 | activation, 188 | bn, 189 | init, 190 | conv=nn.Conv2d, 191 | batch_norm=BatchNorm2d, 192 | bias=bias, 193 | preact=preact, 194 | name=name, 195 | instance_norm=instance_norm, 196 | instance_norm_func=nn.InstanceNorm2d 197 | ) 198 | 199 | 200 | class FC(nn.Sequential): 201 | 202 | def __init__( 203 | self, 204 | in_size: int, 205 | out_size: int, 206 | *, 207 | activation=nn.ReLU(inplace=True), 208 | bn: bool = False, 209 | init=None, 210 | preact: bool = False, 211 | name: str = "" 212 | ): 213 | super().__init__() 214 | 215 | fc = nn.Linear(in_size, out_size, bias=not bn) 216 | if init is not None: 217 | init(fc.weight) 218 | if not bn: 219 | nn.init.constant(fc.bias, 0) 220 | 221 | if preact: 222 | if bn: 223 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 224 | 225 | if activation is not None: 226 | self.add_module(name + 'activation', activation) 227 | 228 | self.add_module(name + 'fc', fc) 229 | 230 | if not preact: 231 | if bn: 232 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 233 | 234 | if activation is not None: 235 | self.add_module(name + 'activation', activation) 236 | 237 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='pointnet2', 6 | ext_modules=[ 7 | CUDAExtension('pointnet2_cuda', [ 8 | 'src/pointnet2_api.cpp', 9 | 10 | 'src/ball_query.cpp', 11 | 'src/ball_query_gpu.cu', 12 | 'src/group_points.cpp', 13 | 'src/group_points_gpu.cu', 14 | 'src/interpolate.cpp', 15 | 'src/interpolate_gpu.cu', 16 | 'src/sampling.cpp', 17 | 'src/sampling_gpu.cu', 18 | ], 19 | extra_compile_args={'cxx': ['-g'], 20 | 'nvcc': ['-O2']}) 21 | ], 22 | cmdclass={'build_ext': BuildExtension} 23 | ) 24 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "ball_query_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 11 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 13 | 14 | int ball_query_wrapper_fast(int b, int n, int m, float 
radius, int nsample, 15 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { 16 | CHECK_INPUT(new_xyz_tensor); 17 | CHECK_INPUT(xyz_tensor); 18 | const float *new_xyz = new_xyz_tensor.data(); 19 | const float *xyz = xyz_tensor.data(); 20 | int *idx = idx_tensor.data(); 21 | 22 | cudaStream_t stream = THCState_getCurrentStream(state); 23 | ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); 24 | return 1; 25 | } -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "ball_query_gpu.h" 6 | #include "cuda_utils.h" 7 | 8 | 9 | __global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, 10 | const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { 11 | // new_xyz: (B, M, 3) 12 | // xyz: (B, N, 3) 13 | // output: 14 | // idx: (B, M, nsample) 15 | int bs_idx = blockIdx.y; 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 17 | if (bs_idx >= b || pt_idx >= m) return; 18 | 19 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 20 | xyz += bs_idx * n * 3; 21 | idx += bs_idx * m * nsample + pt_idx * nsample; 22 | 23 | float radius2 = radius * radius; 24 | float new_x = new_xyz[0]; 25 | float new_y = new_xyz[1]; 26 | float new_z = new_xyz[2]; 27 | 28 | int cnt = 0; 29 | for (int k = 0; k < n; ++k) { 30 | float x = xyz[k * 3 + 0]; 31 | float y = xyz[k * 3 + 1]; 32 | float z = xyz[k * 3 + 2]; 33 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 34 | if (d2 < radius2){ 35 | if (cnt == 0){ 36 | for (int l = 0; l < nsample; ++l) { 37 | idx[l] = k; 38 | } 39 | } 40 | idx[cnt] = k; 41 | ++cnt; 42 | if (cnt >= nsample) break; 43 | } 44 | } 45 | } 46 | 47 | 48 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ 49 | const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { 50 | // new_xyz: (B, M, 3) 51 | // xyz: (B, N, 3) 52 | // output: 53 | // idx: (B, M, nsample) 54 | 55 | cudaError_t err; 56 | 57 | dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 58 | dim3 threads(THREADS_PER_BLOCK); 59 | 60 | ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); 61 | // cudaDeviceSynchronize(); // for using printf in kernel function 62 | err = cudaGetLastError(); 63 | if (cudaSuccess != err) { 64 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 65 | exit(-1); 66 | } 67 | } -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/ball_query_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_GPU_H 2 | #define _BALL_QUERY_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 10 | at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor); 11 | 12 | void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, 13 | const float *xyz, const float *new_xyz, int *idx, cudaStream_t stream); 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | 
#ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define TOTAL_THREADS 1024 7 | #define THREADS_PER_BLOCK 256 8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 9 | 10 | inline int opt_n_threads(int work_size) { 11 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 12 | 13 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | #endif 16 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "group_points_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | 11 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 12 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 13 | 14 | float *grad_points = grad_points_tensor.data(); 15 | const int *idx = idx_tensor.data(); 16 | const float *grad_out = grad_out_tensor.data(); 17 | 18 | cudaStream_t stream = THCState_getCurrentStream(state); 19 | 20 | group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream); 21 | return 1; 22 | } 23 | 24 | 25 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 26 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) { 27 | 28 | const float *points = points_tensor.data(); 29 | const int *idx = idx_tensor.data(); 30 | float *out = out_tensor.data(); 31 | 32 | cudaStream_t stream = THCState_getCurrentStream(state); 33 | 34 | group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream); 35 | return 1; 36 | } -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | #include "group_points_gpu.h" 6 | 7 | 8 | __global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, 9 | const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) { 10 | // grad_out: (B, C, npoints, nsample) 11 | // idx: (B, npoints, nsample) 12 | // output: 13 | // grad_points: (B, C, N) 14 | int bs_idx = blockIdx.z; 15 | int c_idx = blockIdx.y; 16 | int index = blockIdx.x * blockDim.x + threadIdx.x; 17 | int pt_idx = index / nsample; 18 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 19 | 20 | int sample_idx = index % nsample; 21 | grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 22 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 23 | 24 | atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); 25 | } 26 | 27 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 28 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) { 29 | // grad_out: (B, C, npoints, nsample) 30 | // idx: (B, npoints, nsample) 31 | // output: 32 | // grad_points: (B, C, N) 33 | cudaError_t err; 34 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 35 | dim3 threads(THREADS_PER_BLOCK); 36 | 37 | group_points_grad_kernel_fast<<>>(b, c, n, npoints, nsample, grad_out, idx, grad_points); 38 | 39 | err = cudaGetLastError(); 40 | if (cudaSuccess 
!= err) { 41 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 42 | exit(-1); 43 | } 44 | } 45 | 46 | 47 | __global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, 48 | const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) { 49 | // points: (B, C, N) 50 | // idx: (B, npoints, nsample) 51 | // output: 52 | // out: (B, C, npoints, nsample) 53 | int bs_idx = blockIdx.z; 54 | int c_idx = blockIdx.y; 55 | int index = blockIdx.x * blockDim.x + threadIdx.x; 56 | int pt_idx = index / nsample; 57 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 58 | 59 | int sample_idx = index % nsample; 60 | 61 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 62 | int in_idx = bs_idx * c * n + c_idx * n + idx[0]; 63 | int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 64 | 65 | out[out_idx] = points[in_idx]; 66 | } 67 | 68 | 69 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 70 | const float *points, const int *idx, float *out, cudaStream_t stream) { 71 | // points: (B, C, N) 72 | // idx: (B, npoints, nsample) 73 | // output: 74 | // out: (B, C, npoints, nsample) 75 | cudaError_t err; 76 | dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 77 | dim3 threads(THREADS_PER_BLOCK); 78 | 79 | group_points_kernel_fast<<>>(b, c, n, npoints, nsample, points, idx, out); 80 | // cudaDeviceSynchronize(); // for using printf in kernel function 81 | err = cudaGetLastError(); 82 | if (cudaSuccess != err) { 83 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 84 | exit(-1); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/group_points_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUP_POINTS_GPU_H 2 | #define _GROUP_POINTS_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 11 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 12 | 13 | void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 14 | const float *points, const int *idx, float *out, cudaStream_t stream); 15 | 16 | int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "interpolate_gpu.h" 10 | 11 | extern THCState *state; 12 | 13 | 14 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 15 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) { 16 | const float *unknown = unknown_tensor.data(); 17 | const float *known = known_tensor.data(); 18 | float *dist2 = dist2_tensor.data(); 19 | int *idx = idx_tensor.data(); 20 | 21 | 
cudaStream_t stream = THCState_getCurrentStream(state); 22 | three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream); 23 | } 24 | 25 | 26 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, 27 | at::Tensor points_tensor, 28 | at::Tensor idx_tensor, 29 | at::Tensor weight_tensor, 30 | at::Tensor out_tensor) { 31 | 32 | const float *points = points_tensor.data(); 33 | const float *weight = weight_tensor.data(); 34 | float *out = out_tensor.data(); 35 | const int *idx = idx_tensor.data(); 36 | 37 | cudaStream_t stream = THCState_getCurrentStream(state); 38 | three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream); 39 | } 40 | 41 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, 42 | at::Tensor grad_out_tensor, 43 | at::Tensor idx_tensor, 44 | at::Tensor weight_tensor, 45 | at::Tensor grad_points_tensor) { 46 | 47 | const float *grad_out = grad_out_tensor.data(); 48 | const float *weight = weight_tensor.data(); 49 | float *grad_points = grad_points_tensor.data(); 50 | const int *idx = idx_tensor.data(); 51 | 52 | cudaStream_t stream = THCState_getCurrentStream(state); 53 | three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream); 54 | } -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | #include "interpolate_gpu.h" 7 | 8 | 9 | __global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, 10 | const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) { 11 | // unknown: (B, N, 3) 12 | // known: (B, M, 3) 13 | // output: 14 | // dist2: (B, N, 3) 15 | // idx: (B, N, 3) 16 | 17 | int bs_idx = blockIdx.y; 18 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 19 | if (bs_idx >= b || pt_idx >= n) return; 20 | 21 | unknown += bs_idx * n * 3 + pt_idx * 3; 22 | known += bs_idx * m * 3; 23 | dist2 += bs_idx * n * 3 + pt_idx * 3; 24 | idx += bs_idx * n * 3 + pt_idx * 3; 25 | 26 | float ux = unknown[0]; 27 | float uy = unknown[1]; 28 | float uz = unknown[2]; 29 | 30 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 31 | int besti1 = 0, besti2 = 0, besti3 = 0; 32 | for (int k = 0; k < m; ++k) { 33 | float x = known[k * 3 + 0]; 34 | float y = known[k * 3 + 1]; 35 | float z = known[k * 3 + 2]; 36 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 37 | if (d < best1) { 38 | best3 = best2; besti3 = besti2; 39 | best2 = best1; besti2 = besti1; 40 | best1 = d; besti1 = k; 41 | } 42 | else if (d < best2) { 43 | best3 = best2; besti3 = besti2; 44 | best2 = d; besti2 = k; 45 | } 46 | else if (d < best3) { 47 | best3 = d; besti3 = k; 48 | } 49 | } 50 | dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; 51 | idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; 52 | } 53 | 54 | 55 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 56 | const float *known, float *dist2, int *idx, cudaStream_t stream) { 57 | // unknown: (B, N, 3) 58 | // known: (B, M, 3) 59 | // output: 60 | // dist2: (B, N, 3) 61 | // idx: (B, N, 3) 62 | 63 | cudaError_t err; 64 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) 65 | dim3 threads(THREADS_PER_BLOCK); 66 | 67 | three_nn_kernel_fast<<>>(b, n, m, unknown, known, dist2, idx); 68 | 69 | err = 
cudaGetLastError(); 70 | if (cudaSuccess != err) { 71 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 72 | exit(-1); 73 | } 74 | } 75 | 76 | 77 | __global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, 78 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) { 79 | // points: (B, C, M) 80 | // idx: (B, N, 3) 81 | // weight: (B, N, 3) 82 | // output: 83 | // out: (B, C, N) 84 | 85 | int bs_idx = blockIdx.z; 86 | int c_idx = blockIdx.y; 87 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 88 | 89 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 90 | 91 | weight += bs_idx * n * 3 + pt_idx * 3; 92 | points += bs_idx * c * m + c_idx * m; 93 | idx += bs_idx * n * 3 + pt_idx * 3; 94 | out += bs_idx * c * n + c_idx * n; 95 | 96 | out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; 97 | } 98 | 99 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 100 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) { 101 | // points: (B, C, M) 102 | // idx: (B, N, 3) 103 | // weight: (B, N, 3) 104 | // output: 105 | // out: (B, C, N) 106 | 107 | cudaError_t err; 108 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 109 | dim3 threads(THREADS_PER_BLOCK); 110 | three_interpolate_kernel_fast<<>>(b, c, m, n, points, idx, weight, out); 111 | 112 | err = cudaGetLastError(); 113 | if (cudaSuccess != err) { 114 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 115 | exit(-1); 116 | } 117 | } 118 | 119 | 120 | __global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 121 | const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) { 122 | // grad_out: (B, C, N) 123 | // weight: (B, N, 3) 124 | // output: 125 | // grad_points: (B, C, M) 126 | 127 | int bs_idx = blockIdx.z; 128 | int c_idx = blockIdx.y; 129 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 130 | 131 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 132 | 133 | grad_out += bs_idx * c * n + c_idx * n + pt_idx; 134 | weight += bs_idx * n * 3 + pt_idx * 3; 135 | grad_points += bs_idx * c * m + c_idx * m; 136 | idx += bs_idx * n * 3 + pt_idx * 3; 137 | 138 | 139 | atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); 140 | atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); 141 | atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); 142 | } 143 | 144 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 145 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream) { 146 | // grad_out: (B, C, N) 147 | // weight: (B, N, 3) 148 | // output: 149 | // grad_points: (B, C, M) 150 | 151 | cudaError_t err; 152 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b); // blockIdx.x(col), blockIdx.y(row) 153 | dim3 threads(THREADS_PER_BLOCK); 154 | three_interpolate_grad_kernel_fast<<>>(b, c, n, m, grad_out, idx, weight, grad_points); 155 | 156 | err = cudaGetLastError(); 157 | if (cudaSuccess != err) { 158 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 159 | exit(-1); 160 | } 161 | } -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/interpolate_gpu.h: 
-------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATE_GPU_H 2 | #define _INTERPOLATE_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | 10 | void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 11 | at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor); 12 | 13 | void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 14 | const float *known, float *dist2, int *idx, cudaStream_t stream); 15 | 16 | 17 | void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, 18 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor); 19 | 20 | void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 21 | const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream); 22 | 23 | 24 | void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, 25 | at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor); 26 | 27 | void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 28 | const int *idx, const float *weight, float *grad_points, cudaStream_t stream); 29 | 30 | #endif 31 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/pointnet2_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ball_query_gpu.h" 5 | #include "group_points_gpu.h" 6 | #include "sampling_gpu.h" 7 | #include "interpolate_gpu.h" 8 | 9 | 10 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 11 | m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast"); 12 | 13 | m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast"); 14 | m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast"); 15 | 16 | m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast"); 17 | m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast"); 18 | 19 | m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper"); 20 | 21 | m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast"); 22 | m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast"); 23 | m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast"); 24 | } 25 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "sampling_gpu.h" 7 | 8 | extern THCState *state; 9 | 10 | 11 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 12 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){ 13 | const float *points = points_tensor.data(); 14 | const int *idx = idx_tensor.data(); 15 | float *out = out_tensor.data(); 16 | 17 | cudaStream_t stream = THCState_getCurrentStream(state); 18 | gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream); 19 | return 1; 20 | } 21 | 22 | 23 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 24 | at::Tensor 
grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) { 25 | 26 | const float *grad_out = grad_out_tensor.data(); 27 | const int *idx = idx_tensor.data(); 28 | float *grad_points = grad_points_tensor.data(); 29 | 30 | cudaStream_t stream = THCState_getCurrentStream(state); 31 | gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream); 32 | return 1; 33 | } 34 | 35 | 36 | int furthest_point_sampling_wrapper(int b, int n, int m, 37 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) { 38 | 39 | const float *points = points_tensor.data(); 40 | float *temp = temp_tensor.data(); 41 | int *idx = idx_tensor.data(); 42 | 43 | cudaStream_t stream = THCState_getCurrentStream(state); 44 | furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream); 45 | return 1; 46 | } 47 | -------------------------------------------------------------------------------- /lib/csrc/pointnet2/src/sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_GPU_H 2 | #define _SAMPLING_GPU_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | int gather_points_wrapper_fast(int b, int c, int n, int npoints, 10 | at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 11 | 12 | void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 13 | const float *points, const int *idx, float *out, cudaStream_t stream); 14 | 15 | 16 | int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 17 | at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 18 | 19 | void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 20 | const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream); 21 | 22 | 23 | int furthest_point_sampling_wrapper(int b, int n, int m, 24 | at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 25 | 26 | void furthest_point_sampling_kernel_launcher(int b, int n, int m, 27 | const float *dataset, float *temp, int *idxs, cudaStream_t stream); 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Object files 5 | *.o 6 | *.ko 7 | *.obj 8 | *.elf 9 | 10 | # Linker output 11 | *.ilk 12 | *.map 13 | *.exp 14 | 15 | # Precompiled Headers 16 | *.gch 17 | *.pch 18 | 19 | # Libraries 20 | *.lib 21 | *.a 22 | *.la 23 | *.lo 24 | 25 | # Shared objects (inc. 
Windows DLLs) 26 | *.dll 27 | *.so 28 | *.so.* 29 | *.dylib 30 | 31 | # Executables 32 | *.exe 33 | *.out 34 | *.app 35 | *.i*86 36 | *.x86_64 37 | *.hex 38 | 39 | # Debug files 40 | *.dSYM/ 41 | *.su 42 | *.idb 43 | *.pdb 44 | 45 | # Kernel Module Compile Results 46 | *.mod* 47 | *.cmd 48 | .tmp_versions/ 49 | modules.order 50 | Module.symvers 51 | Mkfile.old 52 | dkms.conf 53 | 54 | 55 | # Byte-compiled / optimized / DLL files 56 | __pycache__/ 57 | *.py[cod] 58 | *$py.class 59 | 60 | # C extensions 61 | *.so 62 | 63 | # Distribution / packaging 64 | .Python 65 | build/ 66 | develop-eggs/ 67 | dist/ 68 | downloads/ 69 | eggs/ 70 | .eggs/ 71 | lib/ 72 | lib64/ 73 | parts/ 74 | sdist/ 75 | var/ 76 | wheels/ 77 | *.egg-info/ 78 | .installed.cfg 79 | *.egg 80 | MANIFEST 81 | 82 | # PyInstaller 83 | # Usually these files are written by a python script from a template 84 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 85 | *.manifest 86 | *.spec 87 | 88 | # Installer logs 89 | pip-log.txt 90 | pip-delete-this-directory.txt 91 | 92 | # Unit test / coverage reports 93 | htmlcov/ 94 | .tox/ 95 | .coverage 96 | .coverage.* 97 | .cache 98 | nosetests.xml 99 | coverage.xml 100 | *.cover 101 | .hypothesis/ 102 | .pytest_cache/ 103 | 104 | # Translations 105 | *.mo 106 | *.pot 107 | 108 | # Django stuff: 109 | *.log 110 | local_settings.py 111 | db.sqlite3 112 | 113 | # Flask stuff: 114 | instance/ 115 | .webassets-cache 116 | 117 | # Scrapy stuff: 118 | .scrapy 119 | 120 | # Sphinx documentation 121 | docs/_build/ 122 | 123 | # PyBuilder 124 | target/ 125 | 126 | # Jupyter Notebook 127 | .ipynb_checkpoints 128 | 129 | # pyenv 130 | .python-version 131 | 132 | # celery beat schedule file 133 | celerybeat-schedule 134 | 135 | # SageMath parsed files 136 | *.sage.py 137 | 138 | # Environments 139 | .env 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Inria (Antoine Liutkus) 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Custom CUDA kernel for searchsorted 2 | 3 | This repository is an implementation of the searchsorted function that works for pytorch CUDA Tensors. Initially derived from the great [C extension tutorial](https://github.com/chrischoy/pytorch-custom-cuda-tutorial), but totally changed since then because building C extensions is not available anymore on pytorch 1.0. 4 | 5 | 6 | > Warnings: 7 | > * only works with pytorch > v1.3 and CUDA >= v10.1 8 | > * **NOTE** When using `searchsorted()` for practical applications, tensors need to be contiguous in memory. This can be easily achieved by calling `tensor.contiguous()` on the input tensors. Failing to do so _will_ lead to inconsistent results across applications. 9 | 10 | ## Description 11 | 12 | Implements a function `searchsorted(a, v, out, side)` that works just like the [numpy version](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted) except that `a` and `v` are matrices. 13 | * `a` is of shape either `(1, ncols_a)` or `(nrows, ncols_a)`, and is contiguous in memory (do `a.contiguous()` to ensure this). 14 | * `v` is of shape either `(1, ncols_v)` or `(nrows, ncols_v)`, and is contiguous in memory (do `v.contiguous()` to ensure this). 15 | * `out` is either `None` or of shape `(nrows, ncols_v)`. If provided and of the right shape, the result is put there. This is to avoid costly memory allocations if the user already did it. If provided, `out` should be contiguous in memory too (do `out.contiguous()` to ensure this). 16 | * `side` is either "left" or "right". See the [numpy doc](https://docs.scipy.org/doc/numpy/reference/generated/numpy.searchsorted.html#numpy.searchsorted). Please note that the current implementation *does not correctly handle this parameter*. Help is welcome to improve the speed of [this PR](https://github.com/aliutkus/torchsearchsorted/pull/7). 17 | 18 | The output is of size `(nrows, ncols_v)`. If all input tensors are on GPU, a CUDA version will be called. Otherwise, it will run on the CPU. 19 | 20 | 21 | **Disclaimers** 22 | 23 | * This function has not been heavily tested. Use at your own risk. 24 | * When `a` is not sorted, the results vary from numpy's version. But I decided not to care about this because the function should not be called in this case. 25 | * In some cases, the results vary from numpy's version. However, as far as I could see, this only happens when values are equal, which means we actually don't care about the order in which this value is added. I decided not to care about this also. 26 | * vectors have to be contiguous for torchsearchsorted to give consistent results. 
use `.contiguous()` on all tensor arguments before calling 27 | 28 | 29 | ## Installation 30 | 31 | Just `pip install .`, in the root folder of this repo. This will compile 32 | and install the torchsearchsorted module. 33 | 34 | be careful that sometimes, `nvcc` needs versions of `gcc` and `g++` that are older than those found by default on the system. If so, just create symbolic links to the right versions in your cuda/bin folder (where `nvcc` is) 35 | 36 | For instance, on my machine, I had `gcc` and `g++` v9 installed, but `nvcc` required v8. 37 | So I had to do: 38 | 39 | > sudo apt-get install g++-8 gcc-8 40 | > sudo ln -s /usr/bin/gcc-8 /usr/local/cuda-10.1/bin/gcc 41 | > sudo ln -s /usr/bin/g++-8 /usr/local/cuda-10.1/bin/g++ 42 | 43 | be careful that you need pytorch to be installed on your system. The code was tested on pytorch v1.3 44 | 45 | ## Usage 46 | 47 | Just import the torchsearchsorted package after installation. I typically do: 48 | 49 | ``` 50 | from torchsearchsorted import searchsorted 51 | ``` 52 | 53 | 54 | ## Testing 55 | 56 | Under the `examples` subfolder, you may: 57 | 58 | 1. try `python test.py` with `torch` available. 59 | 60 | ``` 61 | Looking for 50000x1000 values in 50000x300 entries 62 | NUMPY: searchsorted in 4851.592ms 63 | CPU: searchsorted in 4805.432ms 64 | difference between CPU and NUMPY: 0.000 65 | GPU: searchsorted in 1.055ms 66 | difference between GPU and NUMPY: 0.000 67 | 68 | Looking for 50000x1000 values in 50000x300 entries 69 | NUMPY: searchsorted in 4333.964ms 70 | CPU: searchsorted in 4753.958ms 71 | difference between CPU and NUMPY: 0.000 72 | GPU: searchsorted in 0.391ms 73 | difference between GPU and NUMPY: 0.000 74 | ``` 75 | The first run comprises the time of allocation, while the second one does not. 76 | 77 | 2. 
You may also use the nice `benchmark.py` code written by [@baldassarreFe](https://github.com/baldassarreFe), that tests `searchsorted` on many runs: 78 | 79 | ``` 80 | Benchmark searchsorted: 81 | - a [5000 x 300] 82 | - v [5000 x 100] 83 | - reporting fastest time of 20 runs 84 | - each run executes searchsorted 100 times 85 | 86 | Numpy: 4.6302046799100935 87 | CPU: 5.041533078998327 88 | CUDA: 0.0007955809123814106 89 | ``` 90 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/examples/benchmark.py: -------------------------------------------------------------------------------- 1 | import timeit 2 | 3 | import torch 4 | import numpy as np 5 | from torchsearchsorted import searchsorted, numpy_searchsorted 6 | 7 | B = 5_000 8 | A = 300 9 | V = 100 10 | 11 | repeats = 20 12 | number = 100 13 | 14 | print( 15 | f'Benchmark searchsorted:', 16 | f'- a [{B} x {A}]', 17 | f'- v [{B} x {V}]', 18 | f'- reporting fastest time of {repeats} runs', 19 | f'- each run executes searchsorted {number} times', 20 | sep='\n', 21 | end='\n\n' 22 | ) 23 | 24 | 25 | def get_arrays(): 26 | a = np.sort(np.random.randn(B, A), axis=1) 27 | v = np.random.randn(B, V) 28 | out = np.empty_like(v, dtype=np.long) 29 | return a, v, out 30 | 31 | 32 | def get_tensors(device): 33 | a = torch.sort(torch.randn(B, A, device=device), dim=1)[0] 34 | v = torch.randn(B, V, device=device) 35 | out = torch.empty(B, V, device=device, dtype=torch.long) 36 | if torch.cuda.is_available(): 37 | torch.cuda.synchronize() 38 | return a, v, out 39 | 40 | def searchsorted_synchronized(a,v,out=None,side='left'): 41 | out = searchsorted(a,v,out,side) 42 | torch.cuda.synchronize() 43 | return out 44 | 45 | numpy = timeit.repeat( 46 | stmt="numpy_searchsorted(a, v, side='left')", 47 | setup="a, v, out = get_arrays()", 48 | globals=globals(), 49 | repeat=repeats, 50 | number=number 51 | ) 52 | print('Numpy: ', min(numpy), sep='\t') 53 | 54 | cpu = timeit.repeat( 55 | stmt="searchsorted(a, v, out, side='left')", 56 | setup="a, v, out = get_tensors(device='cpu')", 57 | globals=globals(), 58 | repeat=repeats, 59 | number=number 60 | ) 61 | print('CPU: ', min(cpu), sep='\t') 62 | 63 | if torch.cuda.is_available(): 64 | gpu = timeit.repeat( 65 | stmt="searchsorted_synchronized(a, v, out, side='left')", 66 | setup="a, v, out = get_tensors(device='cuda')", 67 | globals=globals(), 68 | repeat=repeats, 69 | number=number 70 | ) 71 | print('CUDA: ', min(gpu), sep='\t') 72 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/examples/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsearchsorted import searchsorted, numpy_searchsorted 3 | import time 4 | 5 | if __name__ == '__main__': 6 | # defining the number of tests 7 | ntests = 2 8 | 9 | # defining the problem dimensions 10 | nrows_a = 50000 11 | nrows_v = 50000 12 | nsorted_values = 300 13 | nvalues = 1000 14 | 15 | # defines the variables. 
The first run will comprise allocation, the 16 | # further ones will not 17 | test_GPU = None 18 | test_CPU = None 19 | 20 | for ntest in range(ntests): 21 | print("\nLooking for %dx%d values in %dx%d entries" % (nrows_v, nvalues, 22 | nrows_a, 23 | nsorted_values)) 24 | 25 | side = 'right' 26 | # generate a matrix with sorted rows 27 | a = torch.randn(nrows_a, nsorted_values, device='cpu') 28 | a = torch.sort(a, dim=1)[0] 29 | # generate a matrix of values to searchsort 30 | v = torch.randn(nrows_v, nvalues, device='cpu') 31 | 32 | # a = torch.tensor([[0., 1.]]) 33 | # v = torch.tensor([[1.]]) 34 | 35 | t0 = time.time() 36 | test_NP = torch.tensor(numpy_searchsorted(a, v, side)) 37 | print('NUMPY: searchsorted in %0.3fms' % (1000*(time.time()-t0))) 38 | t0 = time.time() 39 | test_CPU = searchsorted(a, v, test_CPU, side) 40 | print('CPU: searchsorted in %0.3fms' % (1000*(time.time()-t0))) 41 | # compute the difference between both 42 | error_CPU = torch.norm(test_NP.double() 43 | - test_CPU.double()).numpy() 44 | if error_CPU: 45 | import ipdb; ipdb.set_trace() 46 | print(' difference between CPU and NUMPY: %0.3f' % error_CPU) 47 | 48 | if not torch.cuda.is_available(): 49 | print('CUDA is not available on this machine, cannot go further.') 50 | continue 51 | else: 52 | # now do the CPU 53 | a = a.to('cuda') 54 | v = v.to('cuda') 55 | torch.cuda.synchronize() 56 | # launch searchsorted on those 57 | t0 = time.time() 58 | test_GPU = searchsorted(a, v, test_GPU, side) 59 | torch.cuda.synchronize() 60 | print('GPU: searchsorted in %0.3fms' % (1000*(time.time()-t0))) 61 | 62 | # compute the difference between both 63 | error_CUDA = torch.norm(test_NP.to('cuda').double() 64 | - test_GPU.double()).cpu().numpy() 65 | 66 | print(' difference between GPU and NUMPY: %0.3f' % error_CUDA) 67 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import BuildExtension, CUDA_HOME 3 | from torch.utils.cpp_extension import CppExtension, CUDAExtension 4 | 5 | # In any case, include the CPU version 6 | modules = [ 7 | CppExtension('torchsearchsorted.cpu', 8 | ['src/cpu/searchsorted_cpu_wrapper.cpp']), 9 | ] 10 | 11 | # If nvcc is available, add the CUDA extension 12 | if CUDA_HOME: 13 | modules.append( 14 | CUDAExtension('torchsearchsorted.cuda', 15 | ['src/cuda/searchsorted_cuda_wrapper.cpp', 16 | 'src/cuda/searchsorted_cuda_kernel.cu']) 17 | ) 18 | 19 | tests_require = [ 20 | 'pytest', 21 | ] 22 | 23 | # Now proceed to setup 24 | setup( 25 | name='torchsearchsorted', 26 | version='1.1', 27 | description='A searchsorted implementation for pytorch', 28 | keywords='searchsorted', 29 | author='Antoine Liutkus', 30 | author_email='antoine.liutkus@inria.fr', 31 | packages=find_packages(where='src'), 32 | package_dir={"": "src"}, 33 | ext_modules=modules, 34 | tests_require=tests_require, 35 | extras_require={ 36 | 'test': tests_require, 37 | }, 38 | cmdclass={ 39 | 'build_ext': BuildExtension 40 | } 41 | ) 42 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/src/cpu/searchsorted_cpu_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include "searchsorted_cpu_wrapper.h" 2 | #include 3 | 4 | template 5 | int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, 
bool side_left) 6 | { 7 | /* Evaluates whether a[row,col] < val <= a[row, col+1]*/ 8 | 9 | if (col == ncol - 1) 10 | { 11 | // special case: we are on the right border 12 | if (a[row * ncol + col] <= val){ 13 | return 1;} 14 | else { 15 | return -1;} 16 | } 17 | bool is_lower; 18 | bool is_next_higher; 19 | 20 | if (side_left) { 21 | // a[row, col] < v <= a[row, col+1] 22 | is_lower = (a[row * ncol + col] < val); 23 | is_next_higher = (a[row*ncol + col + 1] >= val); 24 | } else { 25 | // a[row, col] <= v < a[row, col+1] 26 | is_lower = (a[row * ncol + col] <= val); 27 | is_next_higher = (a[row * ncol + col + 1] > val); 28 | } 29 | if (is_lower && is_next_higher) { 30 | // we found the right spot 31 | return 0; 32 | } else if (is_lower) { 33 | // answer is on the right side 34 | return 1; 35 | } else { 36 | // answer is on the left side 37 | return -1; 38 | } 39 | } 40 | 41 | template 42 | int64_t binary_search(scalar_t*a, int64_t row, scalar_t val, int64_t ncol, bool side_left) 43 | { 44 | /* Look for the value `val` within row `row` of matrix `a`, which 45 | has `ncol` columns. 46 | 47 | the `a` matrix is assumed sorted in increasing order, row-wise 48 | 49 | returns: 50 | * -1 if `val` is smaller than the smallest value found within that row of `a` 51 | * `ncol` - 1 if `val` is larger than the largest element of that row of `a` 52 | * Otherwise, return the column index `res` such that: 53 | - a[row, col] < val <= a[row, col+1]. (if side_left), or 54 | - a[row, col] < val <= a[row, col+1] (if not side_left). 55 | */ 56 | 57 | //start with left at 0 and right at number of columns of a 58 | int64_t right = ncol; 59 | int64_t left = 0; 60 | 61 | while (right >= left) { 62 | // take the midpoint of current left and right cursors 63 | int64_t mid = left + (right-left)/2; 64 | 65 | // check the relative position of val: are we good here ? 66 | int rel_pos = eval(val, a, row, mid, ncol, side_left); 67 | // we found the point 68 | if(rel_pos == 0) { 69 | return mid; 70 | } else if (rel_pos > 0) { 71 | if (mid==ncol-1){return ncol-1;} 72 | // the answer is on the right side 73 | left = mid; 74 | } else { 75 | if (mid==0){return -1;} 76 | right = mid; 77 | } 78 | } 79 | return -1; 80 | } 81 | 82 | void searchsorted_cpu_wrapper( 83 | at::Tensor a, 84 | at::Tensor v, 85 | at::Tensor res, 86 | bool side_left) 87 | { 88 | 89 | // Get the dimensions 90 | auto nrow_a = a.size(/*dim=*/0); 91 | auto ncol_a = a.size(/*dim=*/1); 92 | auto nrow_v = v.size(/*dim=*/0); 93 | auto ncol_v = v.size(/*dim=*/1); 94 | 95 | auto nrow_res = fmax(nrow_a, nrow_v); 96 | 97 | //auto acc_v = v.accessor(); 98 | //auto acc_res = res.accessor(); 99 | 100 | AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cpu", [&] { 101 | 102 | scalar_t* a_data = a.data_ptr(); 103 | scalar_t* v_data = v.data_ptr(); 104 | int64_t* res_data = res.data(); 105 | 106 | for (int64_t row = 0; row < nrow_res; row++) 107 | { 108 | for (int64_t col = 0; col < ncol_v; col++) 109 | { 110 | // get the value to look for 111 | int64_t row_in_v = (nrow_v == 1) ? 0 : row; 112 | int64_t row_in_a = (nrow_a == 1) ? 
0 : row; 113 | 114 | int64_t idx_in_v = row_in_v * ncol_v + col; 115 | int64_t idx_in_res = row * ncol_v + col; 116 | 117 | // apply binary search 118 | res_data[idx_in_res] = (binary_search(a_data, row_in_a, v_data[idx_in_v], ncol_a, side_left) + 1); 119 | } 120 | } 121 | }); 122 | } 123 | 124 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 125 | m.def("searchsorted_cpu_wrapper", &searchsorted_cpu_wrapper, "searchsorted (CPU)"); 126 | } 127 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/src/cpu/searchsorted_cpu_wrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef _SEARCHSORTED_CPU 2 | #define _SEARCHSORTED_CPU 3 | 4 | #include 5 | 6 | void searchsorted_cpu_wrapper( 7 | at::Tensor a, 8 | at::Tensor v, 9 | at::Tensor res, 10 | bool side_left); 11 | 12 | #endif -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/src/cuda/searchsorted_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "searchsorted_cuda_kernel.h" 2 | 3 | template 4 | __device__ 5 | int eval(scalar_t val, scalar_t *a, int64_t row, int64_t col, int64_t ncol, bool side_left) 6 | { 7 | /* Evaluates whether a[row,col] < val <= a[row, col+1]*/ 8 | 9 | if (col == ncol - 1) 10 | { 11 | // special case: we are on the right border 12 | if (a[row * ncol + col] <= val){ 13 | return 1;} 14 | else { 15 | return -1;} 16 | } 17 | bool is_lower; 18 | bool is_next_higher; 19 | 20 | if (side_left) { 21 | // a[row, col] < v <= a[row, col+1] 22 | is_lower = (a[row * ncol + col] < val); 23 | is_next_higher = (a[row*ncol + col + 1] >= val); 24 | } else { 25 | // a[row, col] <= v < a[row, col+1] 26 | is_lower = (a[row * ncol + col] <= val); 27 | is_next_higher = (a[row * ncol + col + 1] > val); 28 | } 29 | if (is_lower && is_next_higher) { 30 | // we found the right spot 31 | return 0; 32 | } else if (is_lower) { 33 | // answer is on the right side 34 | return 1; 35 | } else { 36 | // answer is on the left side 37 | return -1; 38 | } 39 | } 40 | 41 | template 42 | __device__ 43 | int binary_search(scalar_t *a, int64_t row, scalar_t val, int64_t ncol, bool side_left) 44 | { 45 | /* Look for the value `val` within row `row` of matrix `a`, which 46 | has `ncol` columns. 47 | 48 | the `a` matrix is assumed sorted in increasing order, row-wise 49 | 50 | Returns 51 | * -1 if `val` is smaller than the smallest value found within that row of `a` 52 | * `ncol` - 1 if `val` is larger than the largest element of that row of `a` 53 | * Otherwise, return the column index `res` such that: 54 | - a[row, col] < val <= a[row, col+1]. (if side_left), or 55 | - a[row, col] < val <= a[row, col+1] (if not side_left). 56 | */ 57 | 58 | //start with left at 0 and right at number of columns of a 59 | int64_t right = ncol; 60 | int64_t left = 0; 61 | 62 | while (right >= left) { 63 | // take the midpoint of current left and right cursors 64 | int64_t mid = left + (right-left)/2; 65 | 66 | // check the relative position of val: are we good here ? 
67 | int rel_pos = eval(val, a, row, mid, ncol, side_left); 68 | // we found the point 69 | if(rel_pos == 0) { 70 | return mid; 71 | } else if (rel_pos > 0) { 72 | if (mid==ncol-1){return ncol-1;} 73 | // the answer is on the right side 74 | left = mid; 75 | } else { 76 | if (mid==0){return -1;} 77 | right = mid; 78 | } 79 | } 80 | return -1; 81 | } 82 | 83 | template 84 | __global__ 85 | void searchsorted_kernel( 86 | int64_t *res, 87 | scalar_t *a, 88 | scalar_t *v, 89 | int64_t nrow_res, int64_t nrow_a, int64_t nrow_v, int64_t ncol_a, int64_t ncol_v, bool side_left) 90 | { 91 | // get current row and column 92 | int64_t row = blockIdx.y*blockDim.y+threadIdx.y; 93 | int64_t col = blockIdx.x*blockDim.x+threadIdx.x; 94 | 95 | // check whether we are outside the bounds of what needs be computed. 96 | if ((row >= nrow_res) || (col >= ncol_v)) { 97 | return;} 98 | 99 | // get the value to look for 100 | int64_t row_in_v = (nrow_v==1) ? 0: row; 101 | int64_t row_in_a = (nrow_a==1) ? 0: row; 102 | int64_t idx_in_v = row_in_v*ncol_v+col; 103 | int64_t idx_in_res = row*ncol_v+col; 104 | 105 | // apply binary search 106 | res[idx_in_res] = binary_search(a, row_in_a, v[idx_in_v], ncol_a, side_left)+1; 107 | } 108 | 109 | 110 | void searchsorted_cuda( 111 | at::Tensor a, 112 | at::Tensor v, 113 | at::Tensor res, 114 | bool side_left){ 115 | 116 | // Get the dimensions 117 | auto nrow_a = a.size(/*dim=*/0); 118 | auto nrow_v = v.size(/*dim=*/0); 119 | auto ncol_a = a.size(/*dim=*/1); 120 | auto ncol_v = v.size(/*dim=*/1); 121 | 122 | auto nrow_res = fmax(double(nrow_a), double(nrow_v)); 123 | 124 | // prepare the kernel configuration 125 | dim3 threads(ncol_v, nrow_res); 126 | dim3 blocks(1, 1); 127 | if (nrow_res*ncol_v > 1024){ 128 | threads.x = int(fmin(double(1024), double(ncol_v))); 129 | threads.y = floor(1024/threads.x); 130 | blocks.x = ceil(double(ncol_v)/double(threads.x)); 131 | blocks.y = ceil(double(nrow_res)/double(threads.y)); 132 | } 133 | 134 | AT_DISPATCH_ALL_TYPES(a.type(), "searchsorted cuda", ([&] { 135 | searchsorted_kernel<<>>( 136 | res.data(), 137 | a.data(), 138 | v.data(), 139 | nrow_res, nrow_a, nrow_v, ncol_a, ncol_v, side_left); 140 | })); 141 | 142 | } 143 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/src/cuda/searchsorted_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _SEARCHSORTED_CUDA_KERNEL 2 | #define _SEARCHSORTED_CUDA_KERNEL 3 | 4 | #include 5 | 6 | void searchsorted_cuda( 7 | at::Tensor a, 8 | at::Tensor v, 9 | at::Tensor res, 10 | bool side_left); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/src/cuda/searchsorted_cuda_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include "searchsorted_cuda_wrapper.h" 2 | 3 | // C++ interface 4 | 5 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 6 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 8 | 9 | void searchsorted_cuda_wrapper(at::Tensor a, at::Tensor v, at::Tensor res, bool side_left) 10 | { 11 | CHECK_INPUT(a); 12 | CHECK_INPUT(v); 13 | CHECK_INPUT(res); 14 | 15 | searchsorted_cuda(a, v, res, side_left); 16 | } 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("searchsorted_cuda_wrapper", 
&searchsorted_cuda_wrapper, "searchsorted (CUDA)"); 20 | } 21 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/src/cuda/searchsorted_cuda_wrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef _SEARCHSORTED_CUDA_WRAPPER 2 | #define _SEARCHSORTED_CUDA_WRAPPER 3 | 4 | #include 5 | #include "searchsorted_cuda_kernel.h" 6 | 7 | void searchsorted_cuda_wrapper( 8 | at::Tensor a, 9 | at::Tensor v, 10 | at::Tensor res, 11 | bool side_left); 12 | 13 | #endif 14 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/src/torchsearchsorted/__init__.py: -------------------------------------------------------------------------------- 1 | from .searchsorted import searchsorted 2 | from .utils import numpy_searchsorted 3 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/src/torchsearchsorted/searchsorted.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | # trying to import the CPU searchsorted 6 | SEARCHSORTED_CPU_AVAILABLE = True 7 | try: 8 | from torchsearchsorted.cpu import searchsorted_cpu_wrapper 9 | except ImportError: 10 | SEARCHSORTED_CPU_AVAILABLE = False 11 | 12 | # trying to import the CUDA searchsorted 13 | SEARCHSORTED_GPU_AVAILABLE = True 14 | try: 15 | from torchsearchsorted.cuda import searchsorted_cuda_wrapper 16 | except ImportError: 17 | SEARCHSORTED_GPU_AVAILABLE = False 18 | 19 | 20 | def searchsorted(a: torch.Tensor, v: torch.Tensor, 21 | out: Optional[torch.LongTensor] = None, 22 | side='left') -> torch.LongTensor: 23 | assert len(a.shape) == 2, "input `a` must be 2-D." 24 | assert len(v.shape) == 2, "input `v` mus(t be 2-D." 25 | assert (a.shape[0] == v.shape[0] 26 | or a.shape[0] == 1 27 | or v.shape[0] == 1), ("`a` and `v` must have the same number of " 28 | "rows or one of them must have only one ") 29 | assert a.device == v.device, '`a` and `v` must be on the same device' 30 | 31 | result_shape = (max(a.shape[0], v.shape[0]), v.shape[1]) 32 | if out is not None: 33 | assert out.device == a.device, "`out` must be on the same device as `a`" 34 | assert out.dtype == torch.long, "out.dtype must be torch.long" 35 | assert out.shape == result_shape, ("If the output tensor is provided, " 36 | "its shape must be correct.") 37 | else: 38 | out = torch.empty(result_shape, device=v.device, dtype=torch.long) 39 | 40 | if a.is_cuda and not SEARCHSORTED_GPU_AVAILABLE: 41 | raise Exception('torchsearchsorted on CUDA device is asked, but it seems ' 42 | 'that it is not available. Please install it') 43 | if not a.is_cuda and not SEARCHSORTED_CPU_AVAILABLE: 44 | raise Exception('torchsearchsorted on CPU is not available. 
' 45 | 'Please install it.') 46 | 47 | left_side = 1 if side=='left' else 0 48 | if a.is_cuda: 49 | searchsorted_cuda_wrapper(a, v, out, left_side) 50 | else: 51 | searchsorted_cpu_wrapper(a, v, out, left_side) 52 | 53 | return out 54 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/src/torchsearchsorted/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def numpy_searchsorted(a: np.ndarray, v: np.ndarray, side='left'): 5 | """Numpy version of searchsorted that works batch-wise on pytorch tensors 6 | """ 7 | nrows_a = a.shape[0] 8 | (nrows_v, ncols_v) = v.shape 9 | nrows_out = max(nrows_a, nrows_v) 10 | out = np.empty((nrows_out, ncols_v), dtype=np.long) 11 | def sel(data, row): 12 | return data[0] if data.shape[0] == 1 else data[row] 13 | for row in range(nrows_out): 14 | out[row] = np.searchsorted(sel(a, row), sel(v, row), side=side) 15 | return out 16 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/test/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | devices = {'cpu': torch.device('cpu')} 5 | if torch.cuda.is_available(): 6 | devices['cuda'] = torch.device('cuda:0') 7 | 8 | 9 | @pytest.fixture(params=devices.values(), ids=devices.keys()) 10 | def device(request): 11 | return request.param 12 | -------------------------------------------------------------------------------- /lib/csrc/torchsearchsorted/test/test_searchsorted.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | import numpy as np 5 | from torchsearchsorted import searchsorted, numpy_searchsorted 6 | from itertools import product, repeat 7 | 8 | 9 | def test_searchsorted_output_dtype(device): 10 | B = 100 11 | A = 50 12 | V = 12 13 | 14 | a = torch.sort(torch.rand(B, V, device=device), dim=1)[0] 15 | v = torch.rand(B, A, device=device) 16 | 17 | out = searchsorted(a, v) 18 | out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy()) 19 | assert out.dtype == torch.long 20 | np.testing.assert_array_equal(out.cpu().numpy(), out_np) 21 | 22 | out = torch.empty(v.shape, dtype=torch.long, device=device) 23 | searchsorted(a, v, out) 24 | assert out.dtype == torch.long 25 | np.testing.assert_array_equal(out.cpu().numpy(), out_np) 26 | 27 | Ba_val = [1, 100, 200] 28 | Bv_val = [1, 100, 200] 29 | A_val = [1, 50, 500] 30 | V_val = [1, 12, 120] 31 | side_val = ['left', 'right'] 32 | nrepeat = 100 33 | 34 | @pytest.mark.parametrize('Ba,Bv,A,V,side', product(Ba_val, Bv_val, A_val, V_val, side_val)) 35 | def test_searchsorted_correct(Ba, Bv, A, V, side, device): 36 | if Ba > 1 and Bv > 1 and Ba != Bv: 37 | return 38 | for test in range(nrepeat): 39 | a = torch.sort(torch.rand(Ba, A, device=device), dim=1)[0] 40 | v = torch.rand(Bv, V, device=device) 41 | out_np = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(), 42 | side=side) 43 | out = searchsorted(a, v, side=side).cpu().numpy() 44 | np.testing.assert_array_equal(out, out_np) 45 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_dataset import make_data_loader 2 | -------------------------------------------------------------------------------- 
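As a quick complement to the torchsearchsorted README, examples, and tests above, here is a minimal usage sketch. It assumes the extension has already been built and installed (`pip install .` inside `lib/csrc/torchsearchsorted`); the `(nrows, ncols)` shape convention and the `.contiguous()` requirement come from the Description section of that README, and the numpy cross-check mirrors `examples/test.py`.

```python
# Minimal usage sketch; assumes the torchsearchsorted extension is built.
import torch
from torchsearchsorted import searchsorted, numpy_searchsorted

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# `a` must be sorted row-wise and contiguous; `v` holds the query values.
a = torch.sort(torch.randn(4, 300, device=device), dim=1)[0].contiguous()
v = torch.randn(4, 100, device=device).contiguous()

out = searchsorted(a, v, side='left')   # shape (4, 100), dtype torch.long
ref = numpy_searchsorted(a.cpu().numpy(), v.cpu().numpy(), side='left')

# Up to ties between equal values, the result matches numpy row-wise.
print((out.cpu().numpy() == ref).all())
```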
/lib/datasets/collate_batch.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import default_collate 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def meta_anisdf_collator(batch): 7 | batch = [[default_collate([b]) for b in batch_] for batch_ in batch] 8 | return batch 9 | 10 | 11 | _collators = {'meta_anisdf': meta_anisdf_collator} 12 | 13 | 14 | def make_collator(cfg, split): 15 | collator = getattr(cfg, split).collator 16 | 17 | if collator in _collators: 18 | return _collators[collator] 19 | else: 20 | return default_collate 21 | -------------------------------------------------------------------------------- /lib/datasets/make_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from .transforms import make_transforms 3 | from . import samplers 4 | import torch 5 | import torch.utils.data 6 | import importlib 7 | import os 8 | from .collate_batch import make_collator 9 | import numpy as np 10 | import time 11 | from lib.config.config import cfg 12 | from termcolor import colored 13 | 14 | torch.multiprocessing.set_sharing_strategy('file_system') 15 | 16 | 17 | def _dataset_factory(split): 18 | splitcfg = getattr(cfg, split) 19 | if hasattr(splitcfg, "dataset_module") and hasattr(splitcfg, "dataset_kwargs"): 20 | module = splitcfg.dataset_module 21 | args = splitcfg.dataset_kwargs 22 | else: 23 | module = getattr(cfg, split + "_dataset_module") 24 | args = getattr(cfg, split + "_dataset") 25 | 26 | dataset = importlib.import_module(module).Dataset(**args) 27 | return dataset 28 | 29 | 30 | def make_dataset(cfg, dataset_name, transforms, split='train'): 31 | dataset = _dataset_factory(split) 32 | return dataset 33 | 34 | 35 | def make_data_sampler(dataset, shuffle, is_distributed, split): 36 | if split == 'train': 37 | if is_distributed: 38 | return samplers.DistributedSampler(dataset, shuffle=shuffle) 39 | if shuffle: 40 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 41 | else: 42 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 43 | elif split == 'test': 44 | if cfg.test.sampler == 'FrameSampler': 45 | sampler = samplers.FrameSampler(dataset, cfg.test.frame_sampler_interval) 46 | return sampler 47 | if is_distributed: 48 | return samplers.DistributedSampler(dataset, shuffle=shuffle) 49 | if shuffle: 50 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 51 | else: 52 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 53 | elif split == 'prune': 54 | sampler = samplers.FrameSampler(dataset, cfg.prune.frame_sampler_interval) 55 | elif split == 'val' and not cfg.record_demo: 56 | sampler = samplers.FrameSampler(dataset, cfg.val.frame_sampler_interval) 57 | elif split == 'val' and cfg.record_demo: 58 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 59 | elif split == 'bullet': 60 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 61 | elif split == 'tmesh': 62 | sampler = samplers.FrameSampler(dataset, cfg.tmesh.frame_sampler_interval) 63 | elif split == 'tdmesh': 64 | sampler = samplers.FrameSampler(dataset, cfg.tdmesh.frame_sampler_interval) 65 | else: 66 | raise NotImplementedError 67 | 68 | return sampler 69 | 70 | 71 | def make_batch_data_sampler(cfg, sampler, batch_size, drop_last, max_iter, split): 72 | batch_sampler = getattr(cfg, split).batch_sampler 73 | sampler_meta = getattr(cfg, split).sampler_meta 74 | 75 | if batch_sampler == 'default': 76 | batch_sampler = 
torch.utils.data.sampler.BatchSampler( 77 | sampler, batch_size, drop_last) 78 | elif batch_sampler == 'image_size': 79 | batch_sampler = samplers.ImageSizeBatchSampler(sampler, batch_size, 80 | drop_last, sampler_meta) 81 | 82 | if max_iter != -1: 83 | batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, max_iter) 84 | 85 | return batch_sampler 86 | 87 | 88 | cv2.setNumThreads(1) # MARK: OpenCV undistort is why all cores are taken 89 | 90 | 91 | def worker_init_fn(worker_id): 92 | cv2.setNumThreads(1) # MARK: OpenCV undistort is why all cores are taken 93 | # previous randomness issue might just come from here 94 | if cfg.fix_random: 95 | np.random.seed(worker_id) 96 | else: 97 | np.random.seed(worker_id + (int(round(time.time() * 1000) % (2**16)))) 98 | 99 | 100 | def make_data_loader(cfg, split='train', is_distributed=False, max_iter=-1): 101 | batch_size = getattr(cfg, split).batch_size 102 | dataset_name = getattr(cfg, split).dataset 103 | if split == 'train': 104 | # shuffle = True 105 | shuffle = cfg.train.shuffle 106 | drop_last = False 107 | elif split == 'test' or split == 'prune' or split == 'val' or split == 'tmesh' or split == 'tdmesh' or split == 'bullet': 108 | shuffle = True if is_distributed else False 109 | drop_last = False 110 | else: 111 | raise NotImplementedError 112 | 113 | transforms = make_transforms(cfg, split) 114 | dataset = make_dataset(cfg, dataset_name, transforms, split) 115 | sampler = make_data_sampler(dataset, shuffle, is_distributed, split) 116 | batch_sampler = make_batch_data_sampler(cfg, sampler, batch_size, 117 | drop_last, max_iter, split) 118 | num_workers = cfg.train.num_workers 119 | if cfg.record_demo and split == 'val': 120 | num_workers = 0 121 | collator = make_collator(cfg, split) 122 | data_loader = torch.utils.data.DataLoader(dataset, 123 | batch_sampler=batch_sampler, 124 | num_workers=num_workers, 125 | collate_fn=collator, 126 | worker_init_fn=worker_init_fn, 127 | pin_memory=True, 128 | prefetch_factor=2) 129 | 130 | return data_loader 131 | -------------------------------------------------------------------------------- /lib/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from torch.utils.data.sampler import BatchSampler 3 | import numpy as np 4 | import torch 5 | import math 6 | import torch.distributed as dist 7 | from lib.config import cfg 8 | 9 | 10 | class ImageSizeBatchSampler(Sampler): 11 | def __init__(self, sampler, batch_size, drop_last, sampler_meta): 12 | self.sampler = sampler 13 | self.batch_size = batch_size 14 | self.drop_last = drop_last 15 | self.strategy = sampler_meta.strategy 16 | self.hmin, self.wmin = sampler_meta.min_hw 17 | self.hmax, self.wmax = sampler_meta.max_hw 18 | self.divisor = 32 19 | if cfg.fix_random: 20 | np.random.seed(0) 21 | 22 | def generate_height_width(self): 23 | if self.strategy == 'origin': 24 | return -1, -1 25 | h = np.random.randint(self.hmin, self.hmax + 1) 26 | w = np.random.randint(self.wmin, self.wmax + 1) 27 | h = (h | (self.divisor - 1)) + 1 28 | w = (w | (self.divisor - 1)) + 1 29 | return h, w 30 | 31 | def __iter__(self): 32 | batch = [] 33 | h, w = self.generate_height_width() 34 | for idx in self.sampler: 35 | batch.append((idx, h, w)) 36 | if len(batch) == self.batch_size: 37 | h, w = self.generate_height_width() 38 | yield batch 39 | batch = [] 40 | if len(batch) > 0 and not self.drop_last: 41 | yield batch 42 | 43 | def __len__(self): 44 | if 
self.drop_last: 45 | return len(self.sampler) // self.batch_size 46 | else: 47 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 48 | 49 | 50 | class IterationBasedBatchSampler(BatchSampler): 51 | """ 52 | Wraps a BatchSampler, resampling from it until 53 | a specified number of iterations have been sampled 54 | """ 55 | 56 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 57 | self.batch_sampler = batch_sampler 58 | self.sampler = self.batch_sampler.sampler 59 | self.num_iterations = num_iterations 60 | self.start_iter = start_iter 61 | 62 | def __iter__(self): 63 | iteration = self.start_iter 64 | while iteration <= self.num_iterations: 65 | for batch in self.batch_sampler: 66 | iteration += 1 67 | if iteration > self.num_iterations: 68 | break 69 | yield batch 70 | 71 | def __len__(self): 72 | return self.num_iterations 73 | 74 | 75 | class DistributedSampler(Sampler): 76 | """Sampler that restricts data loading to a subset of the dataset. 77 | It is especially useful in conjunction with 78 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 79 | process can pass a DistributedSampler instance as a DataLoader sampler, 80 | and load a subset of the original dataset that is exclusive to it. 81 | .. note:: 82 | Dataset is assumed to be of constant size. 83 | Arguments: 84 | dataset: Dataset used for sampling. 85 | num_replicas (optional): Number of processes participating in 86 | distributed training. 87 | rank (optional): Rank of the current process within num_replicas. 88 | """ 89 | 90 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 91 | if num_replicas is None: 92 | if not dist.is_available(): 93 | raise RuntimeError("Requires distributed package to be available") 94 | num_replicas = dist.get_world_size() 95 | if rank is None: 96 | if not dist.is_available(): 97 | raise RuntimeError("Requires distributed package to be available") 98 | rank = dist.get_rank() 99 | self.dataset = dataset 100 | self.num_replicas = num_replicas 101 | self.rank = rank 102 | self.epoch = 0 103 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 104 | self.total_size = self.num_samples * self.num_replicas 105 | self.shuffle = shuffle 106 | 107 | def __iter__(self): 108 | if self.shuffle: 109 | # deterministically shuffle based on epoch 110 | g = torch.Generator() 111 | g.manual_seed(self.epoch) 112 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 113 | else: 114 | indices = torch.arange(len(self.dataset)).tolist() 115 | 116 | # add extra samples to make it evenly divisible 117 | indices += indices[: (self.total_size - len(indices))] 118 | assert len(indices) == self.total_size 119 | 120 | # subsample 121 | offset = self.num_samples * self.rank 122 | indices = indices[offset:offset+self.num_samples] 123 | assert len(indices) == self.num_samples 124 | 125 | return iter(indices) 126 | 127 | def __len__(self): 128 | return self.num_samples 129 | 130 | def set_epoch(self, epoch): 131 | self.epoch = epoch 132 | 133 | 134 | class FrameSampler(Sampler): 135 | """Sampler certain frames for test 136 | """ 137 | 138 | def __init__(self, dataset, frame_sampler_interval): 139 | inds = np.arange(0, len(dataset.ims)) 140 | ni = len(dataset.ims) // dataset.num_cams 141 | inds = inds.reshape(ni, -1)[::frame_sampler_interval] 142 | self.inds = inds.ravel() 143 | 144 | def __iter__(self): 145 | return iter(self.inds) 146 | 147 | def __len__(self): 148 | return len(self.inds) 149 | 
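`FrameSampler` above assumes `dataset.ims` is laid out frame-major (all cameras of frame 0, then all cameras of frame 1, ...) and keeps every `frame_sampler_interval`-th frame together with all of its views. A minimal sketch with a toy dataset — the `ToyDataset` class and its sizes are made up for illustration, while the sampler logic is copied from above — shows which flat indices survive:

```python
import numpy as np


class ToyDataset:
    """Stand-in dataset: 6 frames x 4 cameras, image list ordered frame-major."""
    def __init__(self, num_frames=6, num_cams=4):
        self.num_cams = num_cams
        self.ims = [f'frame{f}_cam{c}.jpg'
                    for f in range(num_frames) for c in range(num_cams)]


class FrameSampler:
    """Same index selection as lib/datasets/samplers.FrameSampler."""
    def __init__(self, dataset, frame_sampler_interval):
        inds = np.arange(0, len(dataset.ims))
        ni = len(dataset.ims) // dataset.num_cams               # number of frames
        inds = inds.reshape(ni, -1)[::frame_sampler_interval]   # keep every k-th frame row
        self.inds = inds.ravel()                                # back to flat image indices

    def __iter__(self):
        return iter(self.inds)

    def __len__(self):
        return len(self.inds)


sampler = FrameSampler(ToyDataset(), frame_sampler_interval=3)
print(list(sampler))  # selects frames 0 and 3 with all 4 cameras -> indices 0..3 and 12..15
```

The same reshape also explains why `len(dataset.ims)` must be divisible by `dataset.num_cams` for the sampler to behave as intended.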
-------------------------------------------------------------------------------- /lib/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | class Compose(object): 2 | def __init__(self, transforms): 3 | self.transforms = transforms 4 | 5 | def __call__(self, img, kpts=None): 6 | for t in self.transforms: 7 | img, kpts = t(img, kpts) 8 | if kpts is None: 9 | return img 10 | else: 11 | return img, kpts 12 | 13 | def __repr__(self): 14 | format_string = self.__class__.__name__ + "(" 15 | for t in self.transforms: 16 | format_string += "\n" 17 | format_string += " {0}".format(t) 18 | format_string += "\n)" 19 | return format_string 20 | 21 | 22 | class ToTensor(object): 23 | def __call__(self, img, kpts): 24 | return img / 255., kpts 25 | 26 | 27 | class Normalize(object): 28 | def __init__(self, mean, std): 29 | self.mean = mean 30 | self.std = std 31 | 32 | def __call__(self, img, kpts): 33 | img -= self.mean 34 | img /= self.std 35 | return img, kpts 36 | 37 | 38 | def make_transforms(cfg, split): 39 | # TODO 这里其实没用... 40 | if split == 'train': 41 | transform = Compose( 42 | [ 43 | ToTensor(), 44 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 45 | ] 46 | ) 47 | elif split == 'test' or split == 'prune' or split == 'val' or split == 'tmesh' or split == 'tdmesh' or split == 'bullet': 48 | transform = Compose( 49 | [ 50 | ToTensor(), 51 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 52 | ] 53 | ) 54 | else: 55 | raise NotImplementedError 56 | 57 | return transform 58 | -------------------------------------------------------------------------------- /lib/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_evaluator import make_evaluator 2 | -------------------------------------------------------------------------------- /lib/evaluators/if_nerf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import lpips as lp 6 | from termcolor import cprint, colored 7 | 8 | try: 9 | from skimage.measure import compare_ssim 10 | except: 11 | from skimage.metrics import structural_similarity as compare_ssim 12 | 13 | from lib.config import cfg 14 | from lib.utils.blend_utils import partnames 15 | 16 | 17 | class Evaluator: 18 | def __init__(self): 19 | self.mse = [] 20 | self.psnr = [] 21 | self.ssim = [] 22 | self.lpips = [] 23 | self.loss_fn = lp.LPIPS(net='vgg', verbose=False).cuda() 24 | self.loss_fn.eval() 25 | for p in self.loss_fn.parameters(): 26 | p.requires_grad_(False) 27 | 28 | def psnr_metric(self, img_pred, img_gt): 29 | mse = np.mean((img_pred - img_gt)**2) 30 | psnr = -10 * np.log(mse) / np.log(10) 31 | return psnr 32 | 33 | def ssim_metric(self, rgb_pred, rgb_gt, batch, epoch=-1): 34 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 35 | H, W = batch['H'].item(), batch['W'].item() 36 | mask_at_box = mask_at_box.reshape(H, W) 37 | 38 | # convert the pixels into an image 39 | if cfg.white_bkgd: 40 | img_pred = np.ones((H, W, 3)) 41 | img_gt = np.ones((H, W, 3)) 42 | else: 43 | img_pred = np.zeros((H, W, 3)) 44 | img_gt = np.zeros((H, W, 3)) 45 | 46 | img_pred[mask_at_box] = rgb_pred 47 | img_gt[mask_at_box] = rgb_gt 48 | 49 | orig_img_pred = img_pred.copy() 50 | orig_img_gt = img_gt.copy() 51 | 52 | if 'crop_bbox' in batch: 53 | img_pred = fill_image(img_pred, batch) 54 | img_gt = fill_image(img_gt, batch) 55 | 56 | 
if epoch != -1: 57 | result_dir = os.path.join(cfg.result_dir, f'comparison_epoch{epoch}') 58 | else: 59 | result_dir = os.path.join(cfg.result_dir, 'comparison') 60 | os.system('mkdir -p {}'.format(result_dir)) 61 | frame_index = batch['frame_index'].item() 62 | view_index = batch['cam_ind'].item() 63 | cv2.imwrite( 64 | '{}/frame{:04d}_view{:04d}.png'.format(result_dir, frame_index, 65 | view_index), 66 | (img_pred[..., [2, 1, 0]] * 255)) 67 | cv2.imwrite( 68 | '{}/frame{:04d}_view{:04d}_gt.png'.format(result_dir, frame_index, 69 | view_index), 70 | (img_gt[..., [2, 1, 0]] * 255)) 71 | 72 | # crop the object region 73 | x, y, w, h = cv2.boundingRect(mask_at_box.astype(np.uint8)) 74 | img_pred = orig_img_pred[y:y + h, x:x + w] 75 | img_gt = orig_img_gt[y:y + h, x:x + w] 76 | # compute the ssim 77 | ssim = compare_ssim(img_pred, img_gt, multichannel=True) 78 | 79 | return ssim 80 | 81 | def evaluate(self, output, batch, epoch=-1): 82 | rgb_pred = output['rgb_map'][0].detach().cpu().numpy() 83 | rgb_gt = batch['rgb'][0].detach().cpu().numpy() 84 | 85 | if cfg.test_full: 86 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 87 | H, W = batch['H'].item(), batch['W'].item() 88 | mask_at_box = mask_at_box.reshape(H, W) 89 | 90 | if cfg.white_bkgd: 91 | img_pred = np.ones((H, W, 3)) 92 | img_gt = np.ones((H, W, 3)) 93 | else: 94 | img_pred = np.zeros((H, W, 3)) 95 | img_gt = np.zeros((H, W, 3)) 96 | 97 | img_pred[mask_at_box] = rgb_pred 98 | img_gt[mask_at_box] = rgb_gt 99 | 100 | if cfg.eval_part != "": 101 | msk = batch['sem_mask'][partnames.index(cfg.eval_part)] 102 | img_pred[~msk] = 0 103 | img_gt[~msk] = 0 104 | 105 | frame_index = batch['frame_index'].item() 106 | view_index = batch['cam_ind'].item() 107 | if epoch != -1: 108 | result_dir = os.path.join(cfg.result_dir, f'comparison_epoch{epoch}') 109 | else: 110 | result_dir = os.path.join(cfg.result_dir, 'comparison') 111 | 112 | if not cfg.fast_eval: 113 | if not os.path.exists(result_dir): 114 | os.mkdir(result_dir) 115 | cv2.imwrite('{}/frame{:04d}_view{:04d}.png'.format(result_dir, frame_index, view_index), (img_pred[..., [2, 1, 0]] * 255)) 116 | cv2.imwrite('{}/frame{:04d}_view{:04d}_gt.png'.format(result_dir, frame_index, view_index), (img_gt[..., [2, 1, 0]] * 255)) 117 | 118 | if cfg.dry_run: 119 | return 120 | 121 | mse = np.mean((img_pred - img_gt)**2) 122 | self.mse.append(mse) 123 | 124 | psnr = self.psnr_metric(img_pred.reshape(-1, 3), img_gt.reshape(-1, 3)) 125 | self.psnr.append(psnr) 126 | 127 | lpips = self.loss_fn( 128 | torch.tensor(img_pred.transpose((2, 0, 1)), dtype=torch.float, device='cuda')[None], 129 | torch.tensor(img_gt.transpose((2, 0, 1)), dtype=torch.float, device='cuda')[None] 130 | )[0].detach().cpu().numpy() 131 | self.lpips.append(lpips) 132 | 133 | breakpoint() 134 | # ssim = self.ssim_metric(rgb_pred, rgb_gt, batch) 135 | ssim = compare_ssim(img_pred, img_gt, channel_axis=2) 136 | self.ssim.append(ssim) 137 | 138 | # print(f"mse: {mse}") 139 | # print(f"psnr: {psnr}") 140 | # print(f"ssim: {ssim}") 141 | # print(f"lpips: {lpips}") 142 | else: 143 | if rgb_gt.sum() == 0: 144 | return 145 | 146 | mse = np.mean((rgb_pred - rgb_gt)**2) 147 | self.mse.append(mse) 148 | 149 | psnr = self.psnr_metric(rgb_pred, rgb_gt) 150 | self.psnr.append(psnr) 151 | 152 | ssim = self.ssim_metric(rgb_pred, rgb_gt, batch, epoch) 153 | self.ssim.append(ssim) 154 | 155 | def summarize(self, epoch=-1): 156 | if cfg.fast_eval: 157 | cprint('WARNING: only saving evaluation metrics, no images will be saved!', 
color='red', attrs=['bold', 'blink']) 158 | 159 | if cfg.dry_run: 160 | return None 161 | 162 | result_dir = cfg.result_dir 163 | print( 164 | colored('the results are saved at {}'.format(result_dir), 165 | 'yellow')) 166 | 167 | if epoch == -1: 168 | result_path = os.path.join(cfg.result_dir, 'metrics.npy') 169 | else: 170 | result_path = os.path.join(cfg.result_dir, 'metrics_epoch{}.npy'.format(epoch)) 171 | 172 | os.system('mkdir -p {}'.format(os.path.dirname(result_path))) 173 | metrics = {'mse': self.mse, 'psnr': self.psnr, 'ssim': self.ssim, 'lpips': self.lpips} 174 | np.save(result_path, metrics) 175 | 176 | ret = {} 177 | print('mse: {}'.format(np.mean(self.mse))) 178 | print('psnr: {}'.format(np.mean(self.psnr))) 179 | print('ssim: {}'.format(np.mean(self.ssim))) 180 | print('lpips: {}'.format(np.mean(self.lpips))) 181 | 182 | ret.update({"psnr": np.mean(self.psnr), "ssim": np.mean(self.ssim), "lpips": np.mean(self.lpips)}) 183 | 184 | self.mse = [] 185 | self.psnr = [] 186 | self.ssim = [] 187 | self.lpips = [] 188 | 189 | return ret 190 | 191 | 192 | def fill_image(img, batch): 193 | orig_H, orig_W = batch['orig_H'].item(), batch['orig_W'].item() 194 | full_img = np.zeros((orig_H, orig_W, 3)) 195 | bbox = batch['crop_bbox'][0].detach().cpu().numpy() 196 | height = bbox[1, 1] - bbox[0, 1] 197 | width = bbox[1, 0] - bbox[0, 0] 198 | full_img[bbox[0, 1]:bbox[1, 1], 199 | bbox[0, 0]:bbox[1, 0]] = img[:height, :width] 200 | return full_img 201 | -------------------------------------------------------------------------------- /lib/evaluators/make_evaluator.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 5 | def _evaluator_factory(cfg): 6 | module = cfg.evaluator_module 7 | evaluator = importlib.import_module(module).Evaluator() 8 | return evaluator 9 | 10 | 11 | def make_evaluator(cfg): 12 | if cfg.skip_eval: 13 | return None 14 | else: 15 | return _evaluator_factory(cfg) 16 | -------------------------------------------------------------------------------- /lib/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_network import make_network 2 | -------------------------------------------------------------------------------- /lib/networks/bw_deform/part_base_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from lib.config import cfg 6 | from lib.networks.make_network import make_viewdir_embedder, make_residual, make_part_color_network, make_part_embedder, make_deformer 7 | from lib.networks.embedders.part_base_embedder import Embedder as HashEmbedder 8 | from lib.networks.embedders.freq_embedder import Embedder as FreqEmbedder 9 | 10 | 11 | class MLP(nn.Module): 12 | def __init__(self, indim=16, outdim=3, d_hidden=64, n_layers=2): 13 | super(MLP, self).__init__() 14 | self.indim = indim 15 | self.outdim = outdim 16 | self.linears = nn.ModuleList([nn.Linear(indim, d_hidden)] + [nn.Linear(d_hidden, d_hidden) for i in range(n_layers - 1)] + [nn.Linear(d_hidden, outdim)]) 17 | self.actvn = nn.Softplus() 18 | 19 | def forward(self, input): 20 | net = input 21 | for i, l in enumerate(self.linears[:-1]): 22 | net = self.actvn(l(net)) 23 | net = self.linears[-1](net) 24 | return net 25 | 26 | 27 | ColorNetwork = MLP 28 | 29 | 30 | class Network(nn.Module): 31 | def __init__(self, partname, pid): 32 | super().__init__() 33 
| self.pid = pid 34 | self.partname = partname 35 | 36 | self.embedder: HashEmbedder = make_part_embedder(cfg, partname, pid) 37 | self.embedder_dir: FreqEmbedder = make_viewdir_embedder(cfg) 38 | self.occ = MLP(self.embedder.out_dim, 1 + cfg.geo_feature_dim, **cfg.network.occ) 39 | indim_rgb = self.embedder.out_dim + self.embedder_dir.out_dim + cfg.geo_feature_dim + cfg.latent_code_dim 40 | self.rgb_latent = nn.Parameter(torch.zeros(cfg.num_latent_code, cfg.latent_code_dim)) 41 | nn.init.kaiming_normal_(self.rgb_latent) 42 | self.rgb = make_part_color_network(cfg, partname, indim=indim_rgb) 43 | 44 | def forward(self, tpts: torch.Tensor, viewdir: torch.Tensor, dists: torch.Tensor, batch): 45 | # tpts: N, 3 46 | # viewdir: N, 3 47 | N, D = tpts.shape 48 | C, L = self.rgb_latent.shape 49 | embedded = self.embedder(tpts, batch) # embedding 50 | hidden: torch.Tensor = self.occ(embedded) # networking 51 | occ = 1 - torch.exp(-self.occ.actvn(hidden[..., :1])) # activation 52 | feature = hidden[..., 1:] 53 | 54 | embedded_dir = self.embedder_dir(viewdir, batch) # embedding 55 | latent_code = self.rgb_latent.gather(dim=0, index=batch['latent_index'].expand(N, L)) # NOTE: ignoring batch dimension 56 | input = torch.cat([embedded, embedded_dir, feature, latent_code], dim=-1) 57 | rgb: torch.Tensor = self.rgb(input) # networking 58 | rgb = rgb.sigmoid() # activation 59 | 60 | raw = torch.cat([rgb, occ], dim=-1) 61 | ret = {'raw': raw, 'occ': occ} 62 | 63 | return ret 64 | -------------------------------------------------------------------------------- /lib/networks/deformers/uv_deformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lib.utils.blend_utils import pts_sample_uv 4 | from lib.networks.make_network import make_embedder 5 | 6 | 7 | class Deformer(nn.Module): 8 | """ 9 | Deform using array 10 | """ 11 | 12 | def __init__(self, deformer_cfg) -> None: 13 | super().__init__() 14 | self.embedder = make_embedder(deformer_cfg) 15 | self.mlp = nn.Sequential( 16 | nn.Linear(self.embedder.out_dim, 32), 17 | nn.Softplus(), 18 | nn.Linear(32, 32), 19 | nn.Softplus(), 20 | nn.Linear(32, 3), 21 | ) 22 | 23 | def forward(self, xyz: torch.Tensor, batch, flag: torch.Tensor = None): 24 | if flag is not None: 25 | B, NP, _ = xyz.shape 26 | # flag: B, N 27 | assert B == 1 28 | ret = torch.zeros(B, NP, 3, device=xyz.device, dtype=xyz.dtype) 29 | inds = flag[0].nonzero(as_tuple=True)[0][:, None].expand(-1, 3) 30 | xyz = xyz[0].gather(dim=0, index=inds) 31 | 32 | uv = pts_sample_uv(xyz, batch['tuv'], batch['tbounds'], mode='bilinear') # uv: B, 2, N 33 | uv = uv.permute(0, 2, 1) # B, N, 2 34 | uv = uv.view(-1, uv.shape[-1]) # B*N, 2 35 | t = batch['frame_dim'].expand(uv.shape[0], -1).float() 36 | uvt = torch.cat([uv, t], dim=-1) # B*N, 3 37 | feat = self.embedder(uvt, batch) 38 | resd = self.mlp(feat) 39 | resd_tan = 0.05 * torch.tanh(resd) # B*N, 3 40 | 41 | if flag is not None: 42 | ret[0, inds[:, 0]] = resd_tan.to(ret.dtype, non_blocking=True) # ignoring batch dimension 43 | return ret 44 | else: 45 | return resd_tan 46 | -------------------------------------------------------------------------------- /lib/networks/embedder.py: -------------------------------------------------------------------------------- 1 | from lib.config import cfg 2 | import torch 3 | import numpy as np 4 | 5 | from torch import nn 6 | from functools import lru_cache 7 | 8 | 9 | class PosEnc(nn.Module): 10 | def __init__(self, multires, 
periodic_fns=[torch.sin, torch.cos], retain_input=True): 11 | super(PosEnc, self).__init__() 12 | freq_bands = 2.**torch.linspace(0., multires-1, steps=multires).cuda() # (multires) 13 | freq_bands = freq_bands[..., None, None].expand(multires, len(periodic_fns), 1).clone() # (multires, 2, 1) 14 | self.freq_bands = nn.Parameter(freq_bands, requires_grad=False) 15 | # self.register_buffer('freq_bands', freq_bands) 16 | self.multires = multires 17 | self.periodic_fns = periodic_fns 18 | self.retain_input = retain_input 19 | 20 | def get_dim(self, dim): 21 | return self.freq_bands.numel() * dim + (dim if self.retain_input else 0) 22 | 23 | # FIXME: LRU_CACHE WILL MAKE YOU UNABLE TO UPDATE INPUT PARAMETER 24 | def forward(self, inputs): 25 | # inputs: B, N, 3 26 | n_b_dim = len(inputs.shape)-1 27 | dim = inputs.shape[-1] 28 | ori_inputs = inputs 29 | inputs = inputs.view(*inputs.shape[:-1], 1, 1, inputs.shape[-1]) # (B, N, 1, 1, 3) 30 | inputs = inputs * self.freq_bands[(None,)*n_b_dim] # (B, N, 1, 1, 3) * (1, 1, multires, 2, 3) -> (B, N, multires, 2, 3) 31 | inputs = torch.cat([self.periodic_fns[i](t) for i, t in enumerate(torch.split(inputs, 1, dim=-2))], dim=-2) 32 | inputs = inputs.view(*ori_inputs.shape[:-1], self.freq_bands.numel() * dim) # (B, N, embed_dim - 3?) 33 | if self.retain_input: 34 | inputs = torch.cat([ori_inputs, inputs], dim=-1) 35 | return inputs 36 | 37 | 38 | def get_embedder(multires, input_dims=3, periodic_fns=[torch.sin, torch.cos], retain_input=True): 39 | embedder = PosEnc(multires, periodic_fns=periodic_fns, retain_input=retain_input) 40 | return embedder, embedder.get_dim(input_dims) 41 | 42 | 43 | xyz_embedder, xyz_dim = get_embedder(cfg.xyz_res) 44 | view_embedder, view_dim = get_embedder(cfg.view_res) 45 | -------------------------------------------------------------------------------- /lib/networks/embedders/freq_embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib.config import cfg 3 | import torch.nn as nn 4 | 5 | class PosEnc(nn.Module): 6 | def __init__(self, multires, periodic_fns=[torch.sin, torch.cos], retain_input=True): 7 | super(PosEnc, self).__init__() 8 | freq_bands = 2.**torch.linspace(0., multires-1, steps=multires) # (multires) 9 | freq_bands = freq_bands[..., None, None].expand(multires, len(periodic_fns), 1).clone() # (multires, 2, 1) 10 | self.freq_bands = nn.Parameter(freq_bands, requires_grad=False) 11 | # self.register_buffer('freq_bands', freq_bands) 12 | self.multires = multires 13 | self.periodic_fns = periodic_fns 14 | self.retain_input = retain_input 15 | 16 | def get_dim(self, dim): 17 | return self.freq_bands.numel() * dim + (dim if self.retain_input else 0) 18 | 19 | # FIXME: LRU_CACHE WILL MAKE YOU UNABLE TO UPDATE INPUT PARAMETER 20 | def forward(self, inputs): 21 | # inputs: B, N, 3 22 | n_b_dim = len(inputs.shape)-1 23 | dim = inputs.shape[-1] 24 | ori_inputs = inputs 25 | inputs = inputs.view(*inputs.shape[:-1], 1, 1, inputs.shape[-1]) # (B, N, 1, 1, 3) 26 | inputs = inputs * self.freq_bands[(None,)*n_b_dim] # (B, N, 1, 1, 3) * (1, 1, multires, 2, 3) -> (B, N, multires, 2, 3) 27 | inputs = torch.cat([self.periodic_fns[i](t) for i, t in enumerate(torch.split(inputs, 1, dim=-2))], dim=-2) 28 | inputs = inputs.view(*ori_inputs.shape[:-1], self.freq_bands.numel() * dim) # (B, N, embed_dim - 3?) 
29 | if self.retain_input: 30 | inputs = torch.cat([ori_inputs, inputs], dim=-1) 31 | return inputs 32 | 33 | def get_embedder(multires, input_dims=3, periodic_fns=[torch.sin, torch.cos], retain_input=True): 34 | embedder = PosEnc(multires, periodic_fns=periodic_fns, retain_input=retain_input) 35 | return embedder, embedder.get_dim(input_dims) 36 | 37 | class Embedder(nn.Module): 38 | def __init__(self, res, input_dims, F=2) -> None: 39 | super().__init__() 40 | self.embedder, self.out_dim = get_embedder(res, input_dims) 41 | 42 | def forward(self, x, batch): 43 | return self.embedder(x) 44 | 45 | xyz_embedder, xyz_dim = get_embedder(cfg.xyz_res) 46 | view_embedder, view_dim = get_embedder(cfg.view_res) 47 | -------------------------------------------------------------------------------- /lib/networks/make_network.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | 5 | def make_network(cfg): 6 | module = cfg.network_module 7 | network = importlib.import_module(module).Network() 8 | return network 9 | 10 | 11 | def make_part_network(gcfg, partname, pid): 12 | from lib.networks.bw_deform.part_base_network import Network 13 | cfg = getattr(gcfg.partnet, partname) 14 | module = cfg.module 15 | network: Network = importlib.import_module(module, partname).Network(partname, pid) 16 | return network 17 | 18 | 19 | def make_embedder(cfg): 20 | module = cfg.embedder.module 21 | kwargs = cfg.embedder.kwargs 22 | embedder = importlib.import_module(module).Embedder(**kwargs) 23 | return embedder 24 | 25 | 26 | def make_part_embedder(gcfg, partname, pid): 27 | this_cfg = getattr(gcfg.partnet, partname) 28 | bbox = this_cfg.bbox 29 | module = this_cfg.embedder.module 30 | kwargs = this_cfg.embedder.kwargs 31 | embedder = importlib.import_module(module).Embedder(bbox=bbox, pid=pid, partname=partname, **kwargs) 32 | return embedder 33 | 34 | 35 | def make_viewdir_embedder(cfg): 36 | module = cfg.viewdir_embedder.module 37 | kwargs = cfg.viewdir_embedder.kwargs 38 | embedder = importlib.import_module(module).Embedder(**kwargs) 39 | return embedder 40 | 41 | 42 | def make_deformer(cfg): 43 | module = cfg.tpose_deformer.module 44 | deformer = importlib.import_module(module).Deformer(deformer_cfg=cfg.tpose_deformer) 45 | return deformer 46 | 47 | 48 | def make_residual(cfg): 49 | if 'color_residual' in cfg: 50 | module = cfg.color_residual.module 51 | kwargs = cfg.color_residual.kwargs 52 | residual = importlib.import_module(module).Residual(**kwargs) 53 | return residual 54 | else: 55 | from lib.networks.residuals.zero_residual import Residual 56 | return Residual() 57 | 58 | 59 | def make_color_network(cfg, **kargs): 60 | if "color_network" in cfg: 61 | module = cfg.color_network.module 62 | kwargs = cfg.color_network.kwargs 63 | elif "network" in cfg and "color" in cfg.network: 64 | if "module" in cfg.network.color: 65 | module = cfg.network.color.module 66 | else: 67 | module = "lib.networks.bw_deform.inb_network" 68 | kwargs = cfg.network.color 69 | 70 | full_args = dict(kwargs, **kargs) 71 | color_network = importlib.import_module(module).ColorNetwork(**full_args) 72 | return color_network 73 | 74 | 75 | def make_part_color_network(gcfg, partname, **kargs): 76 | this_cfg = None 77 | try: 78 | this_cfg = getattr(gcfg.partnet, partname) 79 | assert "color_network" in this_cfg 80 | except: 81 | pass 82 | 83 | module = "lib.networks.bw_deform.part_base_network" 84 | kwargs = {} 85 | if this_cfg is not None: 86 | if "color_network" in 
this_cfg: 87 | if hasattr(this_cfg.color_network, 'module'): 88 | module = this_cfg.color_network.module 89 | if hasattr(this_cfg.color_network, 'kwargs'): 90 | kwargs = this_cfg.color_network.kwargs 91 | elif "network" in this_cfg and "color" in this_cfg.network: 92 | if "module" in this_cfg.network.color: 93 | module = this_cfg.network.color.module 94 | if 'kwargs' in this_cfg.network.color: 95 | kwargs = this_cfg.network.color 96 | 97 | full_args = dict(kwargs, **kargs) 98 | color_network = importlib.import_module(module).ColorNetwork(**full_args) 99 | return color_network 100 | 101 | # def _make_module_factory(cfgname, classname): 102 | # def make_sth(cfg, **kargs): 103 | # module = getattr(cfg, cfgname).module 104 | # kwargs = getattr(cfg, cfgname).kwargs 105 | # full_args = dict(kwargs, **kargs) 106 | # sth = importlib.import_module(module).__dict__[classname](**full_args) 107 | # return sth 108 | # return make_sth 109 | 110 | # possible_modules = [ 111 | # { 112 | # "cfgname": "color_network", 113 | # "classname": "ColorNetwork", 114 | # } 115 | # ] 116 | -------------------------------------------------------------------------------- /lib/networks/renderer/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_renderer import make_renderer -------------------------------------------------------------------------------- /lib/networks/renderer/make_renderer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | from termcolor import colored 4 | 5 | def make_renderer(cfg, network, vis=False, split='train'): 6 | if hasattr(getattr(cfg, split), 'renderer_module'): 7 | module = getattr(cfg, split).renderer_module 8 | renderer = importlib.import_module(module).Renderer(network) 9 | return renderer 10 | 11 | module = cfg.renderer_module 12 | if not vis: 13 | renderer = importlib.import_module(module).Renderer(network) 14 | else: 15 | renderer = importlib.import_module(cfg.renderer_vis_module).Renderer(network) 16 | return renderer 17 | -------------------------------------------------------------------------------- /lib/networks/renderer/nerf_net_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | from lib.config import cfg 4 | 5 | 6 | def raw2outputs(raw, z_vals, white_bkgd=False): 7 | """Transforms model's predictions to semantically meaningful values. 8 | Args: 9 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 10 | z_vals: [num_rays, num_samples along ray]. Integration time. 11 | Returns: 12 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 13 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 14 | acc_map: [num_rays]. Sum of weights along each ray. 15 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 16 | depth_map: [num_rays]. Estimated distance to object. 17 | """ 18 | rgb = raw[..., :-1] # [N_rays, N_samples, 3] 19 | alpha = raw[..., -1] 20 | 21 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 22 | weights = alpha * torch.cumprod( 23 | torch.cat( 24 | [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10], 25 | -1), -1)[:, :-1] 26 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3] 27 | 28 | depth_map = torch.sum(weights * z_vals, -1) 29 | disp_map = 1. 
/ torch.max(1e-10 * torch.ones_like(depth_map).to(depth_map), 30 | depth_map / torch.sum(weights, -1)) 31 | acc_map = torch.sum(weights, -1) 32 | 33 | if white_bkgd: 34 | rgb_map = rgb_map + (1. - acc_map[..., None]) 35 | 36 | return rgb_map, disp_map, acc_map, weights, depth_map 37 | 38 | def raw2outputs_pixelnerf(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False): 39 | """Transforms model's predictions to semantically meaningful values. 40 | Args: 41 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 42 | z_vals: [num_rays, num_samples along ray]. Integration time. 43 | rays_d: [num_rays, 3]. Direction of each ray. 44 | Returns: 45 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 46 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 47 | acc_map: [num_rays]. Sum of weights along each ray. 48 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 49 | depth_map: [num_rays]. Estimated distance to object. 50 | """ 51 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * 52 | dists) 53 | 54 | rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3] 55 | 56 | if cfg.use_occupancy: 57 | alpha = raw[..., 3] 58 | else: 59 | dists = z_vals[..., 1:] - z_vals[..., :-1] 60 | dists = torch.cat([dists, dists[:, -1:]], dim=1) 61 | # dists = torch.cat( 62 | # [dists, 63 | # torch.Tensor([1e10]).expand(dists[..., :1].shape).to(dists)], 64 | # -1) # [N_rays, N_samples] 65 | 66 | dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 67 | noise = 0. 68 | 69 | alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples] 70 | 71 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 72 | weights = alpha * torch.cumprod( 73 | torch.cat( 74 | [torch.ones((alpha.shape[0], 1)).to(alpha), 1. - alpha + 1e-10], 75 | -1), -1)[:, :-1] 76 | rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3] 77 | 78 | depth_map = torch.sum(weights * z_vals, -1) 79 | disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map).to(depth_map), 80 | depth_map / torch.sum(weights, -1)) 81 | acc_map = torch.sum(weights, -1) 82 | 83 | if white_bkgd: 84 | rgb_map = rgb_map + (1. 
- acc_map[..., None]) 85 | 86 | # obtain the biggest occupancy value along each camera ray 87 | occ_map = torch.max(alpha, dim=1)[0] 88 | 89 | return rgb_map, disp_map, acc_map, weights, depth_map, occ_map 90 | 91 | 92 | # Hierarchical sampling (section 5.2) 93 | def sample_pdf(bins, weights, N_samples, det=False): 94 | from torchsearchsorted import searchsorted 95 | 96 | # Get pdf 97 | weights = weights + 1e-5 # prevent nans 98 | pdf = weights / torch.sum(weights, -1, keepdim=True) 99 | cdf = torch.cumsum(pdf, -1) 100 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], 101 | -1) # (batch, len(bins)) 102 | 103 | # Take uniform samples 104 | if det: 105 | u = torch.linspace(0., 1., steps=N_samples).to(cdf) 106 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 107 | else: 108 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]).to(cdf) 109 | 110 | # Invert CDF 111 | u = u.contiguous() 112 | inds = searchsorted(cdf, u, side='right') 113 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 114 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 115 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 116 | 117 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 118 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 119 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 120 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 121 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 122 | 123 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 124 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 125 | t = (u - cdf_g[..., 0]) / denom 126 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 127 | 128 | return samples 129 | 130 | 131 | def get_intersection_mask(sdf, z_vals): 132 | """ 133 | sdf: n_batch, n_pixel, n_sample 134 | z_vals: n_batch, n_pixel, n_sample 135 | """ 136 | sign = torch.sign(sdf[..., :-1] * sdf[..., 1:]) 137 | ind = torch.min(sign * torch.arange(sign.size(2)).flip([0]).to(sign), 138 | dim=2)[1] 139 | sign = sign.min(dim=2)[0] 140 | intersection_mask = sign == -1 141 | return intersection_mask, ind 142 | 143 | 144 | def sphere_tracing(wpts, sdf, z_vals, ray_o, ray_d, decoder): 145 | """ 146 | wpts: n_point, n_sample, 3 147 | sdf: n_point, n_sample 148 | z_vals: n_point, n_sample 149 | ray_o: n_point, 3 150 | ray_d: n_point, 3 151 | """ 152 | sign = torch.sign(sdf[..., :-1] * sdf[..., 1:]) 153 | ind = torch.min(sign * torch.arange(sign.size(1)).flip([0]).to(sign), 154 | dim=1)[1] 155 | 156 | wpts_sdf = sdf[torch.arange(len(ind)), ind] 157 | wpts_start = wpts[torch.arange(len(ind)), ind] 158 | wpts_end = wpts[torch.arange(len(ind)), ind + 1] 159 | 160 | sdf_threshold = 5e-5 161 | unfinished_mask = wpts_sdf.abs() > sdf_threshold 162 | i = 0 163 | while unfinished_mask.sum() != 0 and i < 20: 164 | curr_start = wpts_start[unfinished_mask] 165 | curr_end = wpts_end[unfinished_mask] 166 | 167 | wpts_mid = (curr_start + curr_end) / 2 168 | mid_sdf = decoder(wpts_mid)[:, 0] 169 | 170 | ind_outside = mid_sdf > 0 171 | if ind_outside.sum() > 0: 172 | curr_start[ind_outside] = wpts_mid[ind_outside] 173 | 174 | ind_inside = mid_sdf < 0 175 | if ind_inside.sum() > 0: 176 | curr_end[ind_inside] = wpts_mid[ind_inside] 177 | 178 | wpts_start[unfinished_mask] = curr_start 179 | wpts_end[unfinished_mask] = curr_end 180 | wpts_sdf[unfinished_mask] = mid_sdf 181 | unfinished_mask[unfinished_mask] = (mid_sdf.abs() > 182 | 
sdf_threshold) | (mid_sdf < 0) 183 | 184 | i = i + 1 185 | 186 | # get intersection points 187 | mask = (wpts_sdf.abs() < sdf_threshold) * (wpts_sdf >= 0) 188 | intersection_points = wpts_start[mask] 189 | 190 | ray_o = ray_o[mask] 191 | ray_d = ray_d[mask] 192 | z_vals = (intersection_points[:, 0] - ray_o[:, 0]) / ray_d[:, 0] 193 | 194 | return intersection_points, z_vals, mask 195 | -------------------------------------------------------------------------------- /lib/networks/renderer/pose_mesh_renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from lib.config import cfg 4 | from .nerf_net_utils import * 5 | from .. import embedder 6 | from lib.utils.blend_utils import * 7 | 8 | 9 | class Renderer: 10 | def __init__(self, net): 11 | self.net = net 12 | 13 | def render(self, batch, test=False, epoch=-1): 14 | pts = batch['pts'] 15 | sh = pts.shape 16 | 17 | if epoch != -1: 18 | batch['epoch'] = epoch 19 | 20 | if 'latent_index' not in batch: 21 | batch['latent_index'] = 0 22 | 23 | # volume rendering for each pixel 24 | chunk = 4096 * 32 25 | tpts = pts.reshape(-1, 3) 26 | N = tpts.shape[0] 27 | ret_list = [] 28 | # print(ray_o.shape) 29 | # print(batch['mask_at_box'].shape) 30 | from tqdm import tqdm 31 | for i in tqdm(range(0, N, chunk)): 32 | pts = tpts[i:i + chunk] 33 | viewdir = torch.zeros_like(pts) 34 | ret = self.net(pts, viewdir, None, batch) 35 | ret_list.append({ 36 | "occ": ret['occ'][0] 37 | }) 38 | 39 | breakpoint() 40 | 41 | keys = ret_list[0].keys() 42 | ret = {k: torch.cat([r[k] for r in ret_list], dim=0) for k in keys} 43 | assert "occ" in ret.keys() and len(ret.keys()) == 1 44 | 45 | ret['occ'] = ret['occ'].view(sh[1:-1]).detach().cpu().numpy() 46 | 47 | return ret -------------------------------------------------------------------------------- /lib/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainers import make_trainer 2 | from .optimizer import make_optimizer 3 | from .scheduler import make_lr_scheduler, set_lr_scheduler 4 | from .recorder import make_recorder 5 | 6 | -------------------------------------------------------------------------------- /lib/train/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib.utils.optimizer.radam import RAdam 3 | from lib.config import cfg 4 | 5 | 6 | _optimizer_factory = { 7 | 'adam': torch.optim.Adam, 8 | 'radam': RAdam, 9 | 'sgd': torch.optim.SGD 10 | } 11 | 12 | 13 | def make_optimizer(cfg, net, lr=None, weight_decay=None): 14 | params = [] 15 | lr = cfg.train.lr if lr is None else lr 16 | weight_decay = cfg.train.weight_decay if weight_decay is None else weight_decay 17 | 18 | for key, value in net.named_parameters(): 19 | if not value.requires_grad: 20 | continue 21 | if 'data' in key: 22 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 23 | else: 24 | params += [{"params": [value], "lr": lr * cfg.mlp_weight_decay, "weight_decay": weight_decay}] 25 | 26 | if 'adam' in cfg.train.optim: 27 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, weight_decay=weight_decay, eps=cfg.train.eps) 28 | else: 29 | optimizer = _optimizer_factory[cfg.train.optim](params, lr, momentum=0.9) 30 | 31 | return optimizer 32 | -------------------------------------------------------------------------------- /lib/train/recorder.py: 
-------------------------------------------------------------------------------- 1 | from collections import deque, defaultdict 2 | import torch 3 | from tensorboardX import SummaryWriter 4 | import os 5 | from lib.config.config import cfg 6 | 7 | from termcolor import colored 8 | 9 | 10 | class SmoothedValue(object): 11 | """Track a series of values and provide access to smoothed values over a 12 | window or the global series average. 13 | """ 14 | 15 | def __init__(self, window_size=20): 16 | self.deque = deque(maxlen=window_size) 17 | self.total = 0.0 18 | self.count = 0 19 | 20 | def update(self, value): 21 | self.deque.append(value) 22 | self.count += 1 23 | self.total += value 24 | 25 | @property 26 | def median(self): 27 | d = torch.tensor(list(self.deque)) 28 | return d.median().item() 29 | 30 | @property 31 | def avg(self): 32 | d = torch.tensor(list(self.deque)) 33 | return d.mean().item() 34 | 35 | @property 36 | def global_avg(self): 37 | return self.total / self.count 38 | 39 | def process_inb(img_stats): 40 | for k, vs in img_stats.items(): 41 | for i, v in enumerate(vs): 42 | if len(v.shape) == 2: 43 | v = v[..., None] 44 | vs[i] = v 45 | img_stats[k] = torch.stack(vs) 46 | return img_stats 47 | 48 | 49 | class Recorder(object): 50 | def __init__(self, cfg): 51 | if cfg.local_rank > 0: 52 | return 53 | 54 | log_dir = cfg.record_dir 55 | if not cfg.resume: 56 | print(colored('remove contents of directory %s' % log_dir, 'red')) 57 | os.system('rm -r %s/*' % log_dir) 58 | self.writer = SummaryWriter(log_dir=log_dir) 59 | 60 | # scalars 61 | self.epoch = 0 62 | self.step = 0 63 | self.loss_stats = defaultdict(SmoothedValue) 64 | 65 | # images 66 | self.image_stats = defaultdict(object) 67 | if 'process_' + cfg.task in globals(): 68 | self.processor = globals()['process_' + cfg.task] 69 | else: 70 | self.processor = None 71 | 72 | def update_loss_stats(self, loss_dict): 73 | if cfg.local_rank > 0: 74 | return 75 | for k, v in loss_dict.items(): 76 | self.loss_stats[k].update(v.detach().cpu()) 77 | 78 | def update_image_stats(self, image_stats): 79 | if cfg.local_rank > 0: 80 | return 81 | if self.processor is None: 82 | return 83 | image_stats = self.processor(image_stats) 84 | for k, v in image_stats.items(): 85 | self.image_stats[k] = v.detach().cpu() 86 | 87 | def record(self, prefix, step=-1, loss_stats=None, image_stats=None): 88 | if cfg.local_rank > 0: 89 | return 90 | 91 | pattern = prefix + '/{}' 92 | step = step if step >= 0 else self.step 93 | loss_stats = loss_stats if loss_stats else self.loss_stats 94 | 95 | for k, v in loss_stats.items(): 96 | if isinstance(v, SmoothedValue): 97 | self.writer.add_scalar(pattern.format(k), v.median, step) 98 | else: 99 | self.writer.add_scalar(pattern.format(k), v, step) 100 | 101 | if self.processor is None: 102 | return 103 | image_stats = self.processor(image_stats) if image_stats else self.image_stats 104 | for k, v in image_stats.items(): 105 | self.writer.add_images(pattern.format(k), v, step, dataformats='NHWC') 106 | 107 | def state_dict(self): 108 | if cfg.local_rank > 0: 109 | return 110 | scalar_dict = {} 111 | scalar_dict['step'] = self.step 112 | return scalar_dict 113 | 114 | def load_state_dict(self, scalar_dict): 115 | if cfg.local_rank > 0: 116 | return 117 | self.step = scalar_dict['step'] 118 | 119 | def __str__(self): 120 | if cfg.local_rank > 0: 121 | return 122 | loss_state = [] 123 | for k, v in self.loss_stats.items(): 124 | loss_state.append('{}: {:.4f}'.format(k, v.avg)) 125 | loss_state = ' 
'.join(loss_state) 126 | 127 | recording_state = ' '.join(['epoch: {}', 'step: {}', '{}']) 128 | return recording_state.format(self.epoch, self.step, loss_state) 129 | 130 | 131 | def make_recorder(cfg): 132 | return Recorder(cfg) 133 | -------------------------------------------------------------------------------- /lib/train/scheduler.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from lib.utils.optimizer.lr_scheduler import WarmupMultiStepLR, MultiStepLR, ExponentialLR 3 | from torch.optim.lr_scheduler import CosineAnnealingLR 4 | 5 | 6 | def make_lr_scheduler(cfg, optimizer): 7 | cfg_scheduler = cfg.train.scheduler 8 | if cfg_scheduler.type == 'multi_step': 9 | scheduler = MultiStepLR(optimizer, 10 | milestones=cfg_scheduler.milestones, 11 | gamma=cfg_scheduler.gamma) 12 | elif cfg_scheduler.type == 'exponential': 13 | scheduler = ExponentialLR(optimizer, 14 | decay_epochs=cfg_scheduler.decay_epochs, 15 | gamma=cfg_scheduler.gamma) 16 | elif cfg_scheduler.type == 'cosine': 17 | scheduler = CosineAnnealingLR(optimizer, 18 | T_max=cfg_scheduler.T_max) 19 | return scheduler 20 | 21 | 22 | def set_lr_scheduler(cfg, scheduler): 23 | cfg_scheduler = cfg.train.scheduler 24 | if cfg_scheduler.type == 'multi_step': 25 | scheduler.milestones = Counter(cfg_scheduler.milestones) 26 | elif cfg_scheduler.type == 'exponential': 27 | scheduler.decay_epochs = cfg_scheduler.decay_epochs 28 | scheduler.gamma = cfg_scheduler.gamma 29 | -------------------------------------------------------------------------------- /lib/train/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_trainer import make_trainer 2 | -------------------------------------------------------------------------------- /lib/train/trainers/crit.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | 4 | def reg(x: torch.Tensor) -> torch.Tensor: 5 | # return (x.float()**2).sum() / 2 6 | return x.norm(dim=-1).mean() 7 | 8 | def reg_raw_crit(x: torch.Tensor): 9 | n_batch, n_pts_x2, D = x.shape 10 | n_pts = n_pts_x2 // 2 11 | length = x.norm(dim=-1, keepdim=True) # length 12 | vector = x / (length + 1e-8) # vector direction (normalized to unit sphere) 13 | # loss_length = mse(length[:, n_pts:, :], length[:, :n_pts, :]) 14 | loss_vector = reg((vector[:, n_pts:, :] - vector[:, :n_pts, :])) 15 | # loss = loss_length + loss_vector 16 | loss = loss_vector 17 | return loss 18 | 19 | def sdf_mask_crit(ret, batch): 20 | msk_sdf = ret['msk_sdf'] 21 | msk_label = ret['msk_label'] 22 | 23 | alpha = 50 24 | alpha_factor = 2 25 | alpha_milestones = [10000, 20000, 30000, 40000, 50000] 26 | for milestone in alpha_milestones: 27 | if batch['iter_step'] > milestone: 28 | alpha = alpha * alpha_factor 29 | 30 | msk_sdf = -alpha * msk_sdf 31 | mask_loss = F.binary_cross_entropy_with_logits(msk_sdf, msk_label) / alpha 32 | 33 | return mask_loss 34 | 35 | 36 | def elastic_crit(ret, batch): 37 | """ 38 | resd_jacobian: n_batch, n_point, 3, 3 39 | """ 40 | jac = ret['resd_jacobian'] 41 | U, S, V = torch.svd(jac, compute_uv=True) 42 | log_svals = torch.log(torch.clamp(S, min=1e-6)) 43 | elastic_loss = torch.sum(log_svals**2, dim=2).mean() 44 | return elastic_loss 45 | 46 | 47 | def normal_crit(ret, batch): 48 | surf_normal_pred = ret['surf_normal'][ret['surf_mask']] 49 | surf_normal = batch['normal'][ret['surf_mask']] 50 | 51 | viewdir = 
batch['ray_d'][ret['surf_mask']] 52 | weights = torch.sum(-surf_normal_pred * viewdir, dim=1) 53 | weights = torch.clamp(weights, min=0, max=1)**2 54 | 55 | norm = torch.norm(surf_normal, dim=1) 56 | norm[norm < 1e-8] = 1e-8 57 | surf_normal = surf_normal / norm[..., None] 58 | 59 | surf_normal_pred[:, 1:] = surf_normal_pred[:, 1:] * -1 60 | 61 | normal_loss = torch.norm(surf_normal_pred - surf_normal, dim=1) 62 | normal_loss = (weights * normal_loss).mean() 63 | 64 | return normal_loss 65 | -------------------------------------------------------------------------------- /lib/train/trainers/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zju3dv/instant-nvr/02decc423cc882deffee053cdbdee8b70c8285ec/lib/train/trainers/loss/__init__.py -------------------------------------------------------------------------------- /lib/train/trainers/loss/fourier_loss.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import torch.nn as nn 3 | import torch 4 | 5 | class FourierLoss(torch.nn.Module): 6 | """ 7 | """ 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def fft(self, img_spa): 12 | img_freq = torch.fft.fft2(img_spa) 13 | img_amp = img_freq.abs() 14 | img_angle = img_freq.angle() 15 | return img_amp, img_angle, img_freq 16 | 17 | def compute_channel(self, gt, pred): 18 | gt_amp, gt_angle, gt_freq = self.fft(gt) 19 | pred_amp, pred_angle, pred_freq = self.fft(pred) 20 | amp_loss = torch.abs(gt_amp - pred_amp).mean() 21 | angle_loss = torch.abs(gt_angle - pred_angle).mean() 22 | return amp_loss + angle_loss 23 | 24 | def forward(self, gt, pred): 25 | breakpoint() 26 | H, W, C = gt.shape[-3:] 27 | # assert pred.shape[-3:] == [H, W, C] 28 | 29 | loss = 0.0 30 | 31 | for c in range(C): 32 | gt_channel = gt[..., c] 33 | pred_channel = pred[..., c] 34 | loss += self.compute_channel(gt_channel, pred_channel) 35 | 36 | return loss / C 37 | -------------------------------------------------------------------------------- /lib/train/trainers/loss/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models.vgg as vgg 3 | from collections import namedtuple 4 | 5 | 6 | class LossNetwork(torch.nn.Module): 7 | """Reference: 8 | https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/3 9 | """ 10 | 11 | def __init__(self): 12 | super(LossNetwork, self).__init__() 13 | try: 14 | from torchvision.models import VGG19_Weights 15 | self.vgg_layers = vgg.vgg19(weights=VGG19_Weights.DEFAULT).features 16 | except ImportError: 17 | self.vgg_layers = vgg.vgg19(pretrained=True).features 18 | 19 | for param in self.vgg_layers.parameters(): 20 | param.requires_grad = False 21 | ''' 22 | self.layer_name_mapping = { 23 | '3': "relu1", 24 | '8': "relu2", 25 | '17': "relu3", 26 | '26': "relu4", 27 | '35': "relu5", 28 | } 29 | ''' 30 | 31 | self.layer_name_mapping = {'3': "relu1", '8': "relu2"} 32 | 33 | def forward(self, x): 34 | output = {} 35 | for name, module in self.vgg_layers._modules.items(): 36 | x = module(x) 37 | if name in self.layer_name_mapping: 38 | output[self.layer_name_mapping[name]] = x 39 | if name == '8': 40 | break 41 | LossOutput = namedtuple("LossOutput", ["relu1", "relu2"]) 42 | return LossOutput(**output) 43 | 44 | 45 | class PerceptualLoss(torch.nn.Module): 46 | def __init__(self): 47 | super(PerceptualLoss, self).__init__() 48 | 
49 | self.model = LossNetwork() 50 | self.model.cuda() 51 | self.model.eval() 52 | self.mse_loss = torch.nn.MSELoss(reduction='mean') 53 | self.l1_loss = torch.nn.L1Loss(reduction='mean') 54 | 55 | def forward(self, x, target): 56 | x_feature = self.model(x[:, 0:3, :, :]) 57 | target_feature = self.model(target[:, 0:3, :, :]) 58 | 59 | feature_loss = ( 60 | self.l1_loss(x_feature.relu1, target_feature.relu1) + 61 | self.l1_loss(x_feature.relu2, target_feature.relu2)) / 2.0 62 | 63 | l1_loss = self.l1_loss(x, target) 64 | l2_loss = self.mse_loss(x, target) 65 | 66 | loss = feature_loss + l1_loss + l2_loss 67 | 68 | return loss 69 | -------------------------------------------------------------------------------- /lib/train/trainers/loss/tv_image_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class TVImageLoss(nn.Module): 6 | """ 7 | """ 8 | def __init__(self) -> None: 9 | super().__init__() 10 | 11 | def forward(self, img_pred, img_gt, mask): 12 | # eps = 1e-1 13 | # eps = 0.0 14 | diff_x_gt = torch.square(img_gt[:-1, :, :] - img_gt[1:, :, :]) 15 | diff_y_gt = torch.square(img_gt[:, :-1, :] - img_gt[:, 1:, :]) 16 | eps_x = diff_x_gt.max() 17 | eps_y = diff_y_gt.max() 18 | diff_x = F.relu(torch.square(img_pred[:-1, :, :] - img_pred[1:, :, :]) - eps_x)[mask[:-1, :]].mean() 19 | diff_y = F.relu(torch.square(img_pred[:, :-1, :] - img_pred[:, 1:, :]) - eps_y)[mask[:, :-1]].mean() 20 | loss = (diff_x + diff_y) / 2.0 21 | return loss -------------------------------------------------------------------------------- /lib/train/trainers/make_trainer.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | import importlib 3 | 4 | def _wrapper_factory(cfg, network): 5 | module = cfg.trainer_module 6 | network_wrapper = importlib.import_module(module).NetworkWrapper(network) 7 | return network_wrapper 8 | 9 | 10 | def make_trainer(cfg, network): 11 | network = _wrapper_factory(cfg, network) 12 | return Trainer(network) 13 | 14 | def make_inner_trainer(cfg, network): 15 | breakpoint() 16 | network_wrapper = _wrapper_factory(cfg, network) 17 | return network_wrapper 18 | -------------------------------------------------------------------------------- /lib/utils/base_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os.path as osp 3 | import os 4 | import numpy as np 5 | from pathlib import Path 6 | from termcolor import colored 7 | import torch 8 | from typing import Mapping, TypeVar 9 | KT = TypeVar("KT") # key type 10 | VT = TypeVar("VT") # value type 11 | 12 | def create_dir(name: os.PathLike): 13 | Path(name).mkdir(exist_ok=True, parents=True) 14 | 15 | def create_link(src, tgt): 16 | new_link = os.path.basename(tgt) 17 | if osp.exists(src) and osp.islink(src): 18 | print("Found old latest dir link {} which link to {}, replacing it to {}".format(src, os.readlink(src), tgt)) 19 | os.unlink(src) 20 | os.symlink(new_link, src) 21 | 22 | def dump_cfg(cfg, tgt_path: os.PathLike): 23 | if os.path.exists(tgt_path): 24 | if not cfg.silent: 25 | print(colored("Hey, there exists an experiment with same name before. 
Please make sure you are continuing.", "green")) 26 | return 27 | create_dir(Path(tgt_path).parent) 28 | cfg_str = cfg.dump() 29 | with open(tgt_path, "w") as f: 30 | f.write(cfg_str) 31 | 32 | def git_committed(): 33 | from git import Repo 34 | modified_index = Repo('.').index.diff(None) 35 | files = [] 36 | for ind in modified_index: 37 | file = ind.a_path 38 | if not file[-4:] == 'yaml': 39 | files.append(file) 40 | return len(files) == 0 41 | 42 | def git_hash(): 43 | import git 44 | repo = git.Repo(search_parent_directories=True) 45 | sha = repo.head.object.hexsha 46 | return sha 47 | 48 | def get_time(): 49 | from datetime import datetime 50 | now = datetime.now() 51 | return '_'.join(now.__str__().split(' ')) 52 | 53 | class bcolors: 54 | HEADER = '\033[95m' 55 | OKBLUE = '\033[94m' 56 | OKCYAN = '\033[96m' 57 | OKGREEN = '\033[92m' 58 | WARNING = '\033[93m' 59 | FAIL = '\033[91m' 60 | ENDC = '\033[0m' 61 | BOLD = '\033[1m' 62 | UNDERLINE = '\033[4m' 63 | 64 | 65 | def read_pickle(pkl_path): 66 | with open(pkl_path, 'rb') as f: 67 | return pickle.load(f) 68 | 69 | 70 | def save_pickle(data, pkl_path): 71 | os.system('mkdir -p {}'.format(os.path.dirname(pkl_path))) 72 | with open(pkl_path, 'wb') as f: 73 | pickle.dump(data, f) 74 | 75 | 76 | def project(xyz, K, RT): 77 | """ 78 | xyz: [N, 3] 79 | K: [3, 3] 80 | RT: [3, 4] 81 | """ 82 | xyz = np.dot(xyz, RT[:, :3].T) + RT[:, 3:].T 83 | xyz = np.dot(xyz, K.T) 84 | xy = xyz[:, :2] / xyz[:, 2:] 85 | return xy 86 | 87 | def project_torch(xyz, K, RT): 88 | """ 89 | xyz: [N, 3] 90 | K: [3, 3] 91 | RT: [3, 4] 92 | """ 93 | xyz = torch.mm(xyz, RT[:, :3].T) + RT[:, 3:].T 94 | depth = xyz[..., -1] 95 | xyz = torch.mm(xyz, K.T) 96 | xy = xyz[:, :2] / xyz[:, 2:] 97 | return xy, depth 98 | 99 | 100 | def write_K_pose_inf(K, poses, img_root): 101 | K = K.copy() 102 | K[:2] = K[:2] * 8 103 | K_inf = os.path.join(img_root, 'Intrinsic.inf') 104 | os.system('mkdir -p {}'.format(os.path.dirname(K_inf))) 105 | with open(K_inf, 'w') as f: 106 | for i in range(len(poses)): 107 | f.write('%d\n'%i) 108 | f.write('%f %f %f\n %f %f %f\n %f %f %f\n' % tuple(K.reshape(9).tolist())) 109 | f.write('\n') 110 | 111 | pose_inf = os.path.join(img_root, 'CamPose.inf') 112 | with open(pose_inf, 'w') as f: 113 | for pose in poses: 114 | pose = np.linalg.inv(pose) 115 | A = pose[0:3,:] 116 | tmp = np.concatenate([A[0:3,2].T, A[0:3,0].T,A[0:3,1].T,A[0:3,3].T]) 117 | f.write('%f %f %f %f %f %f %f %f %f %f %f %f\n' % tuple(tmp.tolist())) 118 | 119 | def merge_dicts(dict_a, dict_b, b_append_key): 120 | dict = {} 121 | dict2 = {} 122 | for k in dict_a: 123 | dict.update({k: dict_a[k]}) 124 | dict2.update({k + b_append_key: dict_b[k]}) 125 | dict.update(dict2) 126 | return dict 127 | 128 | 129 | class DotDict(dict, Mapping[KT, VT]): 130 | """ 131 | a dictionary that supports dot notation 132 | as well as dictionary access notation 133 | usage: d = DotDict() or d = DotDict({'val1':'first'}) 134 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 135 | get attributes: d.val2 or d['val2'] 136 | """ 137 | 138 | def update(self, dct=None, **kwargs): 139 | if dct is None: 140 | dct = kwargs 141 | else: 142 | dct.update(kwargs) 143 | for k, v in dct.items(): 144 | if k in self: 145 | target_type = type(self[k]) 146 | if not isinstance(v, target_type): 147 | # NOTE: bool('False') will be True 148 | if target_type == bool and isinstance(v, str): 149 | dct[k] = v == 'True' 150 | else: 151 | dct[k] = target_type(v) 152 | dict.update(self, dct) 153 | 154 | def __hash__(self): 155 | 
return hash(''.join([str(self.values().__hash__())])) 156 | 157 | def __init__(self, dct=None, **kwargs): 158 | if dct is None: 159 | dct = kwargs 160 | else: 161 | dct.update(kwargs) 162 | if dct is not None: 163 | for key, value in dct.items(): 164 | if hasattr(value, 'keys'): 165 | value = DotDict(value) 166 | self[key] = value 167 | 168 | """ 169 | Uncomment following lines and 170 | comment out __getattr__ = dict.__getitem__ to get feature: 171 | 172 | returns empty numpy array for undefined keys, so that you can easily copy things around 173 | TODO: potential caveat, harder to trace where this is set to np.array([], dtype=np.float32) 174 | """ 175 | 176 | def __getitem__(self, key): 177 | try: 178 | return dict.__getitem__(self, key) 179 | except KeyError as e: 180 | raise AttributeError(e) 181 | # MARK: Might encounter exception in newer version of pytorch 182 | # Traceback (most recent call last): 183 | # File "/home/xuzhen/miniconda3/envs/torch/lib/python3.9/multiprocessing/queues.py", line 245, in _feed 184 | # obj = _ForkingPickler.dumps(obj) 185 | # File "/home/xuzhen/miniconda3/envs/torch/lib/python3.9/multiprocessing/reduction.py", line 51, in dumps 186 | # cls(buf, protocol).dump(obj) 187 | # KeyError: '__getstate__' 188 | # MARK: Because you allow your __getattr__() implementation to raise the wrong kind of exception. 189 | __getattr__ = __getitem__ # overridden dict.__getitem__ 190 | # __getattr__ = dict.__getitem__ 191 | __setattr__ = dict.__setitem__ 192 | __delattr__ = dict.__delitem__ -------------------------------------------------------------------------------- /lib/utils/debug_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from termcolor import colored 3 | import subprocess 4 | import time 5 | import shutil 6 | 7 | 8 | def toc(): 9 | return time.time() * 1000 10 | 11 | def myprint(cmd, level): 12 | color = {'run': 'blue', 'info': 'green', 'warn': 'yellow', 'error': 'red'}[level] 13 | print(colored(cmd, color)) 14 | 15 | def log(text): 16 | myprint(text, 'info') 17 | 18 | def log_time(text): 19 | strf = get_time() 20 | print(colored(strf, 'yellow'), colored(text, 'green')) 21 | 22 | def mywarn(text): 23 | myprint(text, 'warn') 24 | 25 | warning_infos = set() 26 | 27 | def oncewarn(text): 28 | if text in warning_infos: 29 | return 30 | warning_infos.add(text) 31 | myprint(text, 'warn') 32 | 33 | 34 | def myerror(text): 35 | myprint(text, 'error') 36 | 37 | def run_cmd(cmd, verbo=True, bg=False): 38 | if verbo: myprint('[run] ' + cmd, 'run') 39 | if bg: 40 | args = cmd.split() 41 | print(args) 42 | p = subprocess.Popen(args) 43 | return [p] 44 | else: 45 | exit_status = os.system(cmd) 46 | if exit_status != 0: 47 | raise RuntimeError 48 | return [] 49 | 50 | def mkdir(path): 51 | if os.path.exists(path): 52 | return 0 53 | log('mkdir {}'.format(path)) 54 | os.makedirs(path, exist_ok=True) 55 | 56 | def cp(srcname, dstname): 57 | mkdir(os.path.dirname(dstname)) 58 | shutil.copyfile(srcname, dstname) 59 | 60 | def check_exists(path): 61 | flag1 = os.path.isfile(path) and os.path.exists(path) 62 | flag2 = os.path.isdir(path) and len(os.listdir(path)) >= 10 63 | return flag1 or flag2 -------------------------------------------------------------------------------- /lib/utils/if_nerf/if_nerf_net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from lib.config import cfg 5 | import trimesh 6 | import imageio
7 | 8 | 9 | def update_loss_img(output, batch): 10 | mse = torch.mean((output['rgb_map'] - batch['rgb'])**2, dim=2)[0] 11 | mse = mse.detach().cpu().numpy().astype(np.float32) 12 | 13 | # load the loss img 14 | img_path = batch['meta']['img_path'][0] 15 | paths = img_path.split('/') 16 | paths[-1] = os.path.basename(img_path).replace('.jpg', '.npy') 17 | loss_img_path = os.path.join(paths[0], 'loss', *paths[1:]) 18 | if os.path.exists(loss_img_path): 19 | loss_img = np.load(loss_img_path) 20 | else: 21 | os.system("mkdir -p '{}'".format(os.path.dirname(loss_img_path))) 22 | H, W = int(cfg.H * cfg.ratio), int(cfg.W * cfg.ratio) 23 | loss_img = mse.mean() * np.ones([H, W]).astype(np.float32) 24 | 25 | coord = batch['img_coord'][0] 26 | coord = coord.detach().cpu().numpy() 27 | loss_img[coord[:, 0], coord[:, 1]] = mse 28 | np.save(loss_img_path, loss_img) 29 | 30 | 31 | def init_smpl(smpl): 32 | data_root = 'data/light_stage' 33 | smpl_dir = os.path.join(data_root, cfg.smpl, cfg.human) 34 | for i in range(cfg.ni): 35 | smpl_path = os.path.join(smpl_dir, '{}.ply'.format(i + 1)) 36 | ply = trimesh.load(smpl_path) 37 | xyz = np.array(ply.vertices).ravel() 38 | smpl.weight.data[i] = torch.FloatTensor(xyz) 39 | return smpl 40 | 41 | 42 | def pts_to_can_pts(pts, batch): 43 | """transform pts from the world coordinate to the smpl coordinate""" 44 | Th = batch['Th'] 45 | pts = pts - Th 46 | R = batch['R'] 47 | pts = torch.matmul(pts, batch['R']) 48 | return pts 49 | 50 | 51 | def pts_to_coords(pts, min_xyz): 52 | pts = pts.clone().detach() 53 | # convert xyz to the voxel coordinate dhw 54 | dhw = pts[..., [2, 1, 0]] 55 | min_dhw = min_xyz[:, [2, 1, 0]] 56 | dhw = dhw - min_dhw[:, None] 57 | dhw = dhw / torch.tensor(cfg.voxel_size).to(dhw) 58 | return dhw 59 | 60 | 61 | def record_mask_depth(output, batch): 62 | img_path = os.path.join(batch['data_root'][0], batch['img_name'][0]) 63 | msk_path = os.path.join('data/train_mask_depth', 'mask', 64 | img_path[:-4] + '.png') 65 | depth_path = os.path.join('data/train_mask_depth', 'depth', 66 | img_path[:-4] + '.png') 67 | 68 | max_depth = 10 69 | if os.path.exists(msk_path): 70 | msk = imageio.imread(msk_path) 71 | depth = imageio.imread(depth_path) 72 | depth = depth / 65535 * max_depth 73 | else: 74 | os.system("mkdir -p '{}'".format(os.path.dirname(msk_path))) 75 | os.system("mkdir -p '{}'".format(os.path.dirname(depth_path))) 76 | H, W = batch['H'].item(), batch['W'].item() 77 | msk = np.zeros([H, W]) 78 | depth = np.zeros([H, W]) 79 | 80 | coord = batch['coord'][0].detach().cpu().numpy() 81 | surf_z = output['surf_z'][0].detach().cpu().numpy() 82 | surf_mask = output['surf_mask'][0].detach().cpu().numpy() 83 | 84 | fg_coord = coord[surf_mask] 85 | bkgd_coord = coord[surf_mask == 0] 86 | 87 | msk[fg_coord[:, 0], fg_coord[:, 1]] = 255 88 | msk[bkgd_coord[:, 0], bkgd_coord[:, 1]] = 0 89 | msk = msk.astype(np.uint8) 90 | 91 | depth[fg_coord[:, 0], fg_coord[:, 1]] = surf_z[surf_mask] 92 | depth = depth / max_depth * 65535 93 | depth = depth.astype(np.uint16) 94 | 95 | imageio.imwrite(msk_path, msk) 96 | imageio.imwrite(depth_path, depth) 97 | -------------------------------------------------------------------------------- /lib/utils/img_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from matplotlib import cm 3 | import matplotlib.pyplot as plt 4 | import matplotlib.patches as patches 5 | import numpy as np 6 | import cv2 7 | 8 | 9 | def unnormalize_img(img, mean, std): 10 | """ 11 | img: [3, 
h, w] 12 | """ 13 | img = img.detach().cpu().clone() 14 | # img = img / 255. 15 | img *= torch.tensor(std).view(3, 1, 1) 16 | img += torch.tensor(mean).view(3, 1, 1) 17 | min_v = torch.min(img) 18 | img = (img - min_v) / (torch.max(img) - min_v) 19 | return img 20 | 21 | 22 | def bgr_to_rgb(img): 23 | return img[:, :, [2, 1, 0]] 24 | 25 | 26 | def horizon_concate(inp0, inp1): 27 | h0, w0 = inp0.shape[:2] 28 | h1, w1 = inp1.shape[:2] 29 | if inp0.ndim == 3: 30 | inp = np.zeros((max(h0, h1), w0 + w1, 3), dtype=inp0.dtype) 31 | inp[:h0, :w0, :] = inp0 32 | inp[:h1, w0:(w0 + w1), :] = inp1 33 | else: 34 | inp = np.zeros((max(h0, h1), w0 + w1), dtype=inp0.dtype) 35 | inp[:h0, :w0] = inp0 36 | inp[:h1, w0:(w0 + w1)] = inp1 37 | return inp 38 | 39 | 40 | def vertical_concate(inp0, inp1): 41 | h0, w0 = inp0.shape[:2] 42 | h1, w1 = inp1.shape[:2] 43 | if inp0.ndim == 3: 44 | inp = np.zeros((h0 + h1, max(w0, w1), 3), dtype=inp0.dtype) 45 | inp[:h0, :w0, :] = inp0 46 | inp[h0:(h0 + h1), :w1, :] = inp1 47 | else: 48 | inp = np.zeros((h0 + h1, max(w0, w1)), dtype=inp0.dtype) 49 | inp[:h0, :w0] = inp0 50 | inp[h0:(h0 + h1), :w1] = inp1 51 | return inp 52 | 53 | 54 | def transparent_cmap(cmap): 55 | """Copy colormap and set alpha values""" 56 | mycmap = cmap 57 | mycmap._init() 58 | mycmap._lut[:,-1] = 0.3 59 | return mycmap 60 | 61 | cmap = transparent_cmap(plt.get_cmap('jet')) 62 | 63 | 64 | def set_grid(ax, h, w, interval=8): 65 | ax.set_xticks(np.arange(0, w, interval)) 66 | ax.set_yticks(np.arange(0, h, interval)) 67 | ax.grid() 68 | ax.set_yticklabels([]) 69 | ax.set_xticklabels([]) 70 | 71 | 72 | color_list = np.array( 73 | [ 74 | 0.000, 0.447, 0.741, 75 | 0.850, 0.325, 0.098, 76 | 0.929, 0.694, 0.125, 77 | 0.494, 0.184, 0.556, 78 | 0.466, 0.674, 0.188, 79 | 0.301, 0.745, 0.933, 80 | 0.635, 0.078, 0.184, 81 | 0.300, 0.300, 0.300, 82 | 0.600, 0.600, 0.600, 83 | 1.000, 0.000, 0.000, 84 | 1.000, 0.500, 0.000, 85 | 0.749, 0.749, 0.000, 86 | 0.000, 1.000, 0.000, 87 | 0.000, 0.000, 1.000, 88 | 0.667, 0.000, 1.000, 89 | 0.333, 0.333, 0.000, 90 | 0.333, 0.667, 0.000, 91 | 0.333, 1.000, 0.000, 92 | 0.667, 0.333, 0.000, 93 | 0.667, 0.667, 0.000, 94 | 0.667, 1.000, 0.000, 95 | 1.000, 0.333, 0.000, 96 | 1.000, 0.667, 0.000, 97 | 1.000, 1.000, 0.000, 98 | 0.000, 0.333, 0.500, 99 | 0.000, 0.667, 0.500, 100 | 0.000, 1.000, 0.500, 101 | 0.333, 0.000, 0.500, 102 | 0.333, 0.333, 0.500, 103 | 0.333, 0.667, 0.500, 104 | 0.333, 1.000, 0.500, 105 | 0.667, 0.000, 0.500, 106 | 0.667, 0.333, 0.500, 107 | 0.667, 0.667, 0.500, 108 | 0.667, 1.000, 0.500, 109 | 1.000, 0.000, 0.500, 110 | 1.000, 0.333, 0.500, 111 | 1.000, 0.667, 0.500, 112 | 1.000, 1.000, 0.500, 113 | 0.000, 0.333, 1.000, 114 | 0.000, 0.667, 1.000, 115 | 0.000, 1.000, 1.000, 116 | 0.333, 0.000, 1.000, 117 | 0.333, 0.333, 1.000, 118 | 0.333, 0.667, 1.000, 119 | 0.333, 1.000, 1.000, 120 | 0.667, 0.000, 1.000, 121 | 0.667, 0.333, 1.000, 122 | 0.667, 0.667, 1.000, 123 | 0.667, 1.000, 1.000, 124 | 1.000, 0.000, 1.000, 125 | 1.000, 0.333, 1.000, 126 | 1.000, 0.667, 1.000, 127 | 0.167, 0.000, 0.000, 128 | 0.333, 0.000, 0.000, 129 | 0.500, 0.000, 0.000, 130 | 0.667, 0.000, 0.000, 131 | 0.833, 0.000, 0.000, 132 | 1.000, 0.000, 0.000, 133 | 0.000, 0.167, 0.000, 134 | 0.000, 0.333, 0.000, 135 | 0.000, 0.500, 0.000, 136 | 0.000, 0.667, 0.000, 137 | 0.000, 0.833, 0.000, 138 | 0.000, 1.000, 0.000, 139 | 0.000, 0.000, 0.167, 140 | 0.000, 0.000, 0.333, 141 | 0.000, 0.000, 0.500, 142 | 0.000, 0.000, 0.667, 143 | 0.000, 0.000, 0.833, 144 | 0.000, 0.000, 1.000, 145 | 0.000, 
0.000, 0.000, 146 | 0.143, 0.143, 0.143, 147 | 0.286, 0.286, 0.286, 148 | 0.429, 0.429, 0.429, 149 | 0.571, 0.571, 0.571, 150 | 0.714, 0.714, 0.714, 151 | 0.857, 0.857, 0.857, 152 | 1.000, 1.000, 1.000, 153 | 0.50, 0.5, 0 154 | ] 155 | ).astype(np.float32) 156 | colors = color_list.reshape((-1, 3)) * 255 157 | colors = np.array(colors, dtype=np.uint8).reshape(len(colors), 1, 1, 3) 158 | 159 | def get_schp_palette(num_cls=256): 160 | # Copied from SCHP 161 | """ Returns the color map for visualizing the segmentation mask. 162 | Inputs: 163 | =num_cls= 164 | Number of classes. 165 | Returns: 166 | The color map. 167 | """ 168 | n = num_cls 169 | palette = [0] * (n * 3) 170 | for j in range(0, n): 171 | lab = j 172 | palette[j * 3 + 0] = 0 173 | palette[j * 3 + 1] = 0 174 | palette[j * 3 + 2] = 0 175 | i = 0 176 | while lab: 177 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 178 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 179 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 180 | i += 1 181 | lab >>= 3 182 | 183 | palette = np.array(palette, dtype=np.uint8) 184 | palette = palette.reshape(-1, 3) # n_cls, 3 185 | return palette 186 | -------------------------------------------------------------------------------- /lib/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | 
self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /lib/utils/optimizer/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from bisect import bisect_right 2 | from collections import Counter 3 | 4 | import torch 5 | 6 | 7 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | milestones, 12 | gamma=0.1, 13 | warmup_factor=1.0 / 3, 14 | warmup_iters=5, 15 | warmup_method="linear", 16 | last_epoch=-1, 17 | ): 18 | if not list(milestones) == sorted(milestones): 19 | raise ValueError( 20 | "Milestones should be a list of" " increasing integers. Got {}", 21 | milestones, 22 | ) 23 | 24 | if warmup_method not in ("constant", "linear"): 25 | raise ValueError( 26 | "Only 'constant' or 'linear' warmup_method accepted" 27 | "got {}".format(warmup_method) 28 | ) 29 | self.milestones = milestones 30 | self.gamma = gamma 31 | self.warmup_factor = warmup_factor 32 | self.warmup_iters = warmup_iters 33 | self.warmup_method = warmup_method 34 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 35 | 36 | def get_lr(self): 37 | warmup_factor = 1 38 | if self.last_epoch < self.warmup_iters: 39 | if self.warmup_method == "constant": 40 | warmup_factor = self.warmup_factor 41 | elif self.warmup_method == "linear": 42 | alpha = float(self.last_epoch) / self.warmup_iters 43 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 44 | return [ 45 | base_lr 46 | * warmup_factor 47 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 48 | for base_lr in self.base_lrs 49 | ] 50 | 51 | 52 | class MultiStepLR(torch.optim.lr_scheduler._LRScheduler): 53 | 54 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): 55 | self.milestones = Counter(milestones) 56 | self.gamma = gamma 57 | super(MultiStepLR, self).__init__(optimizer, last_epoch) 58 | 59 | def get_lr(self): 60 | if self.last_epoch not in self.milestones: 61 | return [group['lr'] for group in self.optimizer.param_groups] 62 | return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] 63 | for group in self.optimizer.param_groups] 64 | 65 | 66 | class ExponentialLR(torch.optim.lr_scheduler._LRScheduler): 67 | 68 | def __init__(self, optimizer, decay_epochs, gamma=0.1, last_epoch=-1): 69 | self.decay_epochs = decay_epochs 70 | self.gamma = gamma 71 | super(ExponentialLR, self).__init__(optimizer, last_epoch) 72 | 73 | def get_lr(self): 74 | return [base_lr * self.gamma ** (self.last_epoch / self.decay_epochs) 75 | for base_lr in self.base_lrs] 76 | -------------------------------------------------------------------------------- /lib/utils/render_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | import cv2 5 | 6 | from lib.config import cfg 7 | 8 | from lib.utils.if_nerf import if_nerf_data_utils as if_nerf_dutils 9 | 10 | 11 | def normalize(x): 12 | return x / np.linalg.norm(x) 13 | 14 | 15 | def viewmatrix(z, up, pos): 16 
| vec2 = normalize(z) 17 | vec0_avg = up 18 | vec1 = normalize(np.cross(vec2, vec0_avg)) 19 | vec0 = normalize(np.cross(vec1, vec2)) 20 | m = np.stack([vec0, vec1, vec2, pos], 1) 21 | return m 22 | 23 | 24 | def ptstocam(pts, c2w): 25 | tt = np.matmul(c2w[:3, :3].T, (pts-c2w[:3, 3])[..., np.newaxis])[..., 0] 26 | return tt 27 | 28 | 29 | def load_cam(ann_file): 30 | if ann_file.endswith('.json'): 31 | annots = json.load(open(ann_file, 'r')) 32 | cams = annots['cams']['20190823'] 33 | else: 34 | annots = np.load(ann_file, allow_pickle=True).item() 35 | cams = annots['cams'] 36 | 37 | K = [] 38 | RT = [] 39 | lower_row = np.array([[0., 0., 0., 1.]]) 40 | 41 | for i in range(len(cams['K'])): 42 | K.append(np.array(cams['K'][i])) 43 | K[i][:2] = K[i][:2] * cfg.ratio 44 | 45 | r = np.array(cams['R'][i]) 46 | t = np.array(cams['T'][i]) / 1000. 47 | r_t = np.concatenate([r, t], 1) 48 | RT.append(np.concatenate([r_t, lower_row], 0)) 49 | 50 | return K, RT 51 | 52 | 53 | def get_center_rayd(K, RT): 54 | H, W = int(cfg.H * cfg.ratio), int(cfg.W * cfg.ratio) 55 | RT = np.array(RT) 56 | ray_o, ray_d = if_nerf_dutils.get_rays(H, W, K, 57 | RT[:3, :3], RT[:3, 3]) 58 | return ray_d[H // 2, W // 2] 59 | 60 | 61 | def gen_path(RT, center=None): 62 | breakpoint() 63 | lower_row = np.array([[0., 0., 0., 1.]]) 64 | 65 | # transfer RT to camera_to_world matrix 66 | RT = np.array(RT) 67 | RT[:] = np.linalg.inv(RT[:]) 68 | 69 | RT = np.concatenate([RT[:, :, 1:2], RT[:, :, 0:1], 70 | -RT[:, :, 2:3], RT[:, :, 3:4]], 2) 71 | 72 | up = normalize(RT[:, :3, 0].sum(0)) # average up vector 73 | z = normalize(RT[0, :3, 2]) 74 | vec1 = normalize(np.cross(z, up)) 75 | vec2 = normalize(np.cross(up, vec1)) 76 | z_off = 0 77 | 78 | if center is None: 79 | center = RT[:, :3, 3].mean(0) 80 | z_off = 1.3 81 | 82 | c2w = np.stack([up, vec1, vec2, center], 1) # a virtual camera at the center with up vector following the negative gravity direction. 
83 | 84 | # get radii for spiral path 85 | tt = ptstocam(RT[:, :3, 3], c2w).T 86 | rads = np.percentile(np.abs(tt), 80, -1) 87 | rads = rads * 1.3 88 | rads = np.array(list(rads) + [1.]) 89 | 90 | render_w2c = [] 91 | for theta in np.linspace(0., 2 * np.pi, cfg.render_views + 1)[:-1]: 92 | # camera position 93 | cam_pos = np.array([0, np.sin(theta), np.cos(theta), 1] * rads) 94 | cam_pos_world = np.dot(c2w[:3, :4], cam_pos) 95 | # z axis 96 | z = normalize(cam_pos_world - 97 | np.dot(c2w[:3, :4], np.array([z_off, 0, 0, 1.]))) 98 | # vector -> 3x4 matrix (camera_to_world) 99 | mat = viewmatrix(z, up, cam_pos_world) 100 | 101 | mat = np.concatenate([mat[:, 1:2], mat[:, 0:1], 102 | -mat[:, 2:3], mat[:, 3:4]], 1) 103 | mat = np.concatenate([mat, lower_row], 0) 104 | mat = np.linalg.inv(mat) 105 | render_w2c.append(mat) 106 | 107 | return render_w2c 108 | 109 | 110 | def read_voxel(frame, args): 111 | voxel_path = os.path.join(args['data_root'], 'voxel', args['human'], 112 | '{}.npz'.format(frame)) 113 | voxel_data = np.load(voxel_path) 114 | occupancy = np.unpackbits(voxel_data['compressed_occupancies']) 115 | occupancy = occupancy.reshape(cfg.res, cfg.res, 116 | cfg.res).astype(np.float32) 117 | bounds = voxel_data['bounds'].astype(np.float32) 118 | return occupancy, bounds 119 | 120 | 121 | def image_rays(RT, K, bounds): 122 | H = cfg.H * cfg.ratio 123 | W = cfg.W * cfg.ratio 124 | ray_o, ray_d = if_nerf_dutils.get_rays(H, W, K, 125 | RT[:3, :3], RT[:3, 3]) 126 | 127 | ray_o = ray_o.reshape(-1, 3).astype(np.float32) 128 | ray_d = ray_d.reshape(-1, 3).astype(np.float32) 129 | near, far, mask_at_box = if_nerf_dutils.get_near_far(bounds, ray_o, ray_d) 130 | near = near.astype(np.float32) 131 | far = far.astype(np.float32) 132 | ray_o = ray_o[mask_at_box] 133 | ray_d = ray_d[mask_at_box] 134 | 135 | center = (bounds[0] + bounds[1]) / 2 136 | scale = np.max(bounds[1] - bounds[0]) 137 | 138 | return ray_o, ray_d, near, far, center, scale, mask_at_box 139 | 140 | 141 | def get_image_rays0(RT0, RT, K, bounds): 142 | """ 143 | Use RT to get the mask_at_box and fill this region with rays emitted from view RT0 144 | """ 145 | H = cfg.H * cfg.ratio 146 | ray_o, ray_d = if_nerf_dutils.get_rays(H, H, K, 147 | RT[:3, :3], RT[:3, 3]) 148 | 149 | ray_o = ray_o.reshape(-1, 3).astype(np.float32) 150 | ray_d = ray_d.reshape(-1, 3).astype(np.float32) 151 | near, far, mask_at_box = if_nerf_dutils.get_near_far(bounds, ray_o, ray_d) 152 | 153 | ray_o, ray_d = if_nerf_dutils.get_rays(H, H, K, 154 | RT0[:3, :3], RT0[:3, 3]) 155 | ray_d = ray_d.reshape(-1, 3).astype(np.float32) 156 | ray_d = ray_d[mask_at_box] 157 | 158 | return ray_d 159 | 160 | 161 | def save_img(img, frame_root, index, mask_at_box): 162 | H = int(cfg.H * cfg.ratio) 163 | rgb_pred = img['rgb_map'][0].detach().cpu().numpy() 164 | mask_at_box = mask_at_box.reshape(H, H) 165 | 166 | img_pred = np.zeros((H, H, 3)) 167 | img_pred[mask_at_box] = rgb_pred 168 | img_pred[:, :, [0, 1, 2]] = img_pred[:, :, [2, 1, 0]] 169 | 170 | print("saved frame %d" % index) 171 | cv2.imwrite(os.path.join(frame_root, '%d.jpg' % index), img_pred * 255) 172 | -------------------------------------------------------------------------------- /lib/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_visualizer import make_visualizer 2 | -------------------------------------------------------------------------------- /lib/visualizers/if_nerf.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from lib.config import cfg 4 | import os 5 | import cv2 6 | from termcolor import colored 7 | import os.path as osp 8 | from lib.utils.base_utils import create_link, get_time 9 | from pathlib import Path 10 | 11 | 12 | class Visualizer: 13 | def __init__(self, name = None): 14 | if name is None: 15 | name = get_time() 16 | self.name = name 17 | self.result_dir = osp.join(cfg.result_dir, name) 18 | Path(self.result_dir).mkdir(exist_ok=True, parents=True) 19 | print( 20 | colored('the results are saved at {}'.format(self.result_dir), 21 | 'yellow')) 22 | 23 | def visualize_image(self, output, batch): 24 | rgb_pred = output['rgb_map'][0].detach().cpu().numpy() 25 | rgb_gt = batch['rgb'][0].detach().cpu().numpy() 26 | print('mse: {}'.format(np.mean((rgb_pred - rgb_gt)**2))) 27 | 28 | if rgb_pred.shape == (1024, 3): 29 | img_pred = rgb_pred.reshape(32, 32, 3) 30 | img_gt = rgb_gt.reshape(32, 32, 3) 31 | breakpoint() 32 | else: 33 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 34 | H, W = batch['H'].item(), batch['W'].item() 35 | mask_at_box = mask_at_box.reshape(H, W) 36 | 37 | if cfg.white_bkgd: 38 | img_pred = np.ones((H, W, 3)) 39 | img_gt = np.ones((H, W, 3)) 40 | else: 41 | img_pred = np.zeros((H, W, 3)) 42 | img_gt = np.zeros((H, W, 3)) 43 | 44 | img_pred[mask_at_box] = rgb_pred 45 | img_gt[mask_at_box] = rgb_gt 46 | 47 | result_dir = os.path.join(self.result_dir, 'comparison') 48 | os.system('mkdir -p {}'.format(result_dir)) 49 | frame_index = batch['frame_index'].item() 50 | view_index = batch['cam_ind'].item() 51 | error_map = np.abs(img_pred - img_gt).sum(axis = -1) 52 | cv2.imwrite( 53 | '{}/frame{:04d}_view{:04d}.png'.format(result_dir, frame_index, 54 | view_index), 55 | (img_pred[..., [2, 1, 0]] * 255)) 56 | cv2.imwrite( 57 | '{}/frame{:04d}_view{:04d}_gt.png'.format(result_dir, frame_index, 58 | view_index), 59 | (img_gt[..., [2, 1, 0]] * 255)) 60 | cv2.imwrite("{}/frame{:04d}_view{:04d}_error.png".format(result_dir, frame_index, view_index), (error_map * 255).astype(np.uint8)) 61 | 62 | # _, (ax1, ax2) = plt.subplots(1, 2) 63 | # ax1.imshow(img_pred) 64 | # ax2.imshow(img_gt) 65 | # plt.show() 66 | 67 | def visualize_normal(self, output, batch): 68 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 69 | H, W = batch['H'].item(), batch['W'].item() 70 | mask_at_box = mask_at_box.reshape(H, W) 71 | surf_mask = mask_at_box.copy() 72 | surf_mask[mask_at_box] = output['surf_mask'][0].detach().cpu().numpy() 73 | 74 | normal_map = np.zeros((H, W, 3)) 75 | normal_map[surf_mask] = output['surf_normal'][ 76 | output['surf_mask']].detach().cpu().numpy() 77 | 78 | normal_map[..., 1:] = normal_map[..., 1:] * -1 79 | norm = np.linalg.norm(normal_map, axis=2) 80 | norm[norm < 1e-8] = 1e-8 81 | normal_map = normal_map / norm[..., None] 82 | normal_map = (normal_map + 1) / 2 83 | 84 | plt.imshow(normal_map) 85 | plt.show() 86 | 87 | def visualize_acc(self, output, batch): 88 | acc_pred = output['acc_map'][0].detach().cpu().numpy() 89 | 90 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 91 | H, W = int(cfg.H * cfg.ratio), int(cfg.W * cfg.ratio) 92 | mask_at_box = mask_at_box.reshape(H, W) 93 | 94 | acc = np.zeros((H, W)) 95 | acc[mask_at_box] = acc_pred 96 | 97 | plt.imshow(acc) 98 | plt.show() 99 | 100 | # acc_path = os.path.join(cfg.result_dir, 'acc') 101 | # i = batch['i'].item() 102 | # cam_ind = 
batch['cam_ind'].item() 103 | # acc_path = os.path.join(acc_path, '{:04d}_{:02d}.jpg'.format(i, cam_ind)) 104 | # os.system('mkdir -p {}'.format(os.path.dirname(acc_path))) 105 | # plt.savefig(acc_path) 106 | 107 | def visualize_depth(self, output, batch): 108 | depth_pred = output['depth_map'][0].detach().cpu().numpy() 109 | 110 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 111 | H, W = int(cfg.H * cfg.ratio), int(cfg.W * cfg.ratio) 112 | mask_at_box = mask_at_box.reshape(H, W) 113 | 114 | depth = np.zeros((H, W)) 115 | depth[mask_at_box] = depth_pred 116 | 117 | plt.imshow(depth) 118 | plt.show() 119 | 120 | # depth_path = os.path.join(cfg.result_dir, 'depth') 121 | # i = batch['i'].item() 122 | # cam_ind = batch['cam_ind'].item() 123 | # depth_path = os.path.join(depth_path, '{:04d}_{:02d}.jpg'.format(i, cam_ind)) 124 | # os.system('mkdir -p {}'.format(os.path.dirname(depth_path))) 125 | # plt.savefig(depth_path) 126 | 127 | def visualize(self, output, batch, split='vis'): 128 | if split == 'vis' or split == 'prune': 129 | self.visualize_image(output, batch) 130 | if split == 'prune': 131 | latest_dir = osp.join(cfg.result_dir, 'latest') 132 | new_link = os.path.basename(self.result_dir) 133 | if osp.exists(latest_dir) and osp.islink(latest_dir): 134 | print("Found old latest dir link {} which link to {}, replacing it to {}".format(latest_dir, os.readlink(latest_dir), self.result_dir)) 135 | os.unlink(latest_dir) 136 | os.symlink(new_link, latest_dir) 137 | elif split == 'tmesh': 138 | breakpoint() 139 | target_path = os.path.join(cfg.result_dir, 'tmesh_{}.npy'.format(self.name)) 140 | import mcubes 141 | import trimesh 142 | np.save(target_path, output['occ']) 143 | create_link(osp.join(cfg.result_dir, "latest.npy"), target_path) 144 | # saving mesh for reference. 145 | cube = output['occ'] 146 | breakpoint() 147 | # thresh = np.median(cube[cube > 0].reshape(-1)) 148 | # thresh = 0.05 149 | # ind = np.argpartition(nonz_error_map, -sample_coord_len)[-sample_coord_len:] 150 | N = (cube > -1).sum() 151 | # NN = int(N * 0.15) 152 | NN = int(N * 0.1) 153 | ccube = cube.reshape(-1) 154 | ind = np.argpartition(ccube, -NN)[-NN:] 155 | thresh = ccube[ind].min() 156 | thresh = 0.1 157 | 158 | cube = np.pad(cube, 10, mode='constant') 159 | verts, triangles = mcubes.marching_cubes(cube, thresh) 160 | verts = (verts - 10) * cfg.voxel_size[0] 161 | verts = verts + batch['tbounds'][0, 0].detach().cpu().numpy() 162 | mesh = trimesh.Trimesh(vertices=verts, faces=triangles) 163 | mesh.export(os.path.join(cfg.result_dir, "tmesh_{}.ply".format(self.name))) 164 | print(thresh) 165 | elif split == 'tdmesh': 166 | breakpoint() 167 | target_path = os.path.join(cfg.result_dir, 'tpose_deform_mesh_{}_{}.npy'.format(self.name, batch['frame_dim'][0].item())) 168 | import mcubes 169 | import trimesh 170 | np.save(target_path, output['occ']) 171 | 172 | # saving mesh for reference. 
173 | cube = output['occ'] 174 | cube = np.pad(cube, 10, mode='constant') 175 | verts, triangles = mcubes.marching_cubes(cube, 0.2) 176 | verts = (verts - 10) * cfg.voxel_size[0] 177 | verts = verts + batch['tbounds'][0, 0].detach().cpu().numpy() 178 | mesh = trimesh.Trimesh(vertices=verts, faces=triangles) 179 | mesh.export(os.path.join(cfg.result_dir, "tpose_deform_mesh_{}_{}.ply".format(self.name, batch['frame_dim'][0].item()))) 180 | else: 181 | raise NotImplementedError 182 | -------------------------------------------------------------------------------- /lib/visualizers/if_nerf_demo.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from sklearn.metrics import brier_score_loss 4 | from lib.config import cfg 5 | import cv2 6 | import os 7 | from termcolor import colored 8 | from lib.utils.base_utils import create_link, get_time 9 | import os.path as osp 10 | from pathlib import Path 11 | 12 | 13 | class Visualizer: 14 | def __init__(self, name=None): 15 | if name is None: 16 | name = get_time() 17 | self.name = name 18 | self.result_dir = osp.join(cfg.result_dir, name) 19 | Path(self.result_dir).mkdir(exist_ok=True, parents=True) 20 | print( 21 | colored('the results are saved at {}'.format(self.result_dir), 22 | 'yellow')) 23 | 24 | def increase_brightness(self, img, value=30.0 / 255.): 25 | hsv = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_BGR2HSV) 26 | h, s, v = cv2.split(hsv) 27 | 28 | v += value 29 | v[v > 1.0] = 1.0 30 | 31 | final_hsv = cv2.merge((h, s, v)) 32 | img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR) 33 | return img.astype(np.float64) 34 | 35 | def visualize(self, output, batch, split='vis'): 36 | rgb_pred = output['rgb_map'][0].detach().cpu().numpy() 37 | 38 | mask_at_box = batch['mask_at_box'][0].detach().cpu().numpy() 39 | H, W = batch['H'].item(), batch['W'].item() 40 | mask_at_box = mask_at_box.reshape(H, W) 41 | 42 | img_pred = np.zeros((H, W, 3)) 43 | if cfg.white_bkgd: 44 | img_pred = img_pred + 1 45 | img_pred[mask_at_box] = rgb_pred 46 | img_pred = img_pred[..., [2, 1, 0]] 47 | breakpoint() 48 | if cfg.add_brightness: 49 | img_pred = self.increase_brightness(img_pred, value=30. / 255.) 
50 | 51 | img_root = self.result_dir 52 | index = batch['view_index'].item() 53 | 54 | cv2.imwrite(os.path.join(img_root, '{:04d}.png'.format(index)), 55 | img_pred * 255) 56 | 57 | def merge_into_video(self, epoch): 58 | name = cfg.exp_name + "_epoch" + str(epoch) 59 | if cfg.add_brightness: 60 | name += "_bright" 61 | cmd = "ffmpeg -r 20 -i {}/%04d.png -c:v libx264 -vf fps=20 -pix_fmt yuv420p {}.mp4".format(self.result_dir, osp.join(self.result_dir, name)) 62 | print(cmd) 63 | os.system(cmd) 64 | cmd2 = "ffmpeg -r 20 -i {}/%04d.png {}.gif".format(self.result_dir, osp.join(self.result_dir, name)) 65 | print(cmd2) 66 | os.system(cmd2) -------------------------------------------------------------------------------- /lib/visualizers/make_visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | def make_visualizer(cfg, name=None, split='test'): 5 | if hasattr(getattr(cfg, split), "visualizer_module"): 6 | module = getattr(getattr(cfg, split), "visualizer_module") 7 | else: 8 | module = cfg.visualizer_module 9 | visualizer = importlib.import_module(module).Visualizer(name) 10 | return visualizer 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipdb 2 | colored_traceback 3 | opencv-python==4.7.0.68 4 | sympy==1.11.1 5 | tensorboardX==2.6 6 | imageio==2.25.1 7 | trimesh==3.20.0 8 | matplotlib==3.7.0 9 | matplotlib-inline==0.1.6 10 | plyfile==0.7.4 11 | lpips==0.1.4 12 | scikit-image==0.19.3 13 | -------------------------------------------------------------------------------- /scripts/eval_monocap.sh: -------------------------------------------------------------------------------- 1 | export GPUS="0," # change to your gpu id 2 | 3 | for name in lan marc olek vlad 4 | do 5 | python train_net.py --cfg_file configs/inb/inb_${name}.yaml exp_name inb_${name} gpus ${GPUS} silent True 6 | python run.py --type evaluate --cfg_file configs/inb/inb_${name}.yaml exp_name inb_${name} gpus ${GPUS} 7 | done -------------------------------------------------------------------------------- /scripts/eval_zjumocap.sh: -------------------------------------------------------------------------------- 1 | export GPUS="0," # change to your gpu id 2 | 3 | for name in 377 386 387 392 393 394 4 | do 5 | python train_net.py --cfg_file configs/inb/inb_${name}.yaml exp_name inb_${name} gpus ${GPUS} silent True 6 | python run.py --type evaluate --cfg_file configs/inb/inb_${name}.yaml exp_name inb_${name} gpus ${GPUS} 7 | done -------------------------------------------------------------------------------- /tools/cropschp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import json 4 | import cv2 5 | import imageio 6 | import matplotlib.pyplot as plt 7 | import tqdm 8 | import argparse 9 | 10 | def myprint(cmd, level): 11 | color = {'run': 'blue', 'info': 'green', 'warn': 'yellow', 'error': 'red'}[level] 12 | print(colored(cmd, color)) 13 | 14 | def log(text): 15 | myprint(text, 'info') 16 | 17 | def mywarn(text): 18 | myprint(text, 'warn') 19 | 20 | def cal_Square(bbox): 21 | if bbox is None or bbox[4]<0.6:return 0 22 | return (bbox[0]-bbox[2])*(bbox[1]-bbox[3]) 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--data_root", default="data", type=str) 28 | args = 
parser.parse_args() 29 | 30 | data_root = args.data_root 31 | 32 | annots_path = os.path.join(data_root, 'annots') 33 | schp_path = os.path.join(data_root, 'schp') 34 | 35 | for sub_dir in os.listdir(schp_path): 36 | sub_schp_path = os.path.join(schp_path, sub_dir) 37 | sub_annot_path = os.path.join(annots_path, sub_dir) 38 | for schp_img in tqdm.tqdm(sorted(os.listdir(sub_schp_path))[0:]): 39 | full_schp_img = os.path.join(sub_schp_path, schp_img) 40 | full_annot = os.path.join(sub_annot_path, schp_img.replace('.png', '.json')) 41 | 42 | # print(annot_path) 43 | with open(full_annot, 'r') as f: 44 | annot = json.load(f) 45 | # print(annot['annots']) 46 | 47 | max_bbox = None 48 | for person in annot['annots']: 49 | # print(person) 50 | # print(cal_Square(person['bbox'])) 51 | if(cal_Square(max_bbox)