├── .DS_Store ├── README.md ├── cameras.html ├── cameras.png ├── configs ├── base_config.yaml └── pnr.yaml ├── dataio ├── PolData.py ├── __init__.py ├── polanalyser │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── demosaicing.cpython-310.pyc │ │ ├── demosaicing.cpython-37.pyc │ │ ├── demosaicing.cpython-38.pyc │ │ ├── demosaicing.cpython-39.pyc │ │ ├── mueller.cpython-310.pyc │ │ ├── mueller.cpython-37.pyc │ │ ├── mueller.cpython-38.pyc │ │ ├── mueller.cpython-39.pyc │ │ ├── stokes.cpython-310.pyc │ │ ├── stokes.cpython-37.pyc │ │ ├── stokes.cpython-38.pyc │ │ ├── stokes.cpython-39.pyc │ │ ├── stokes.cvtStokesToAoLP-116.py37m.1.nbc │ │ ├── stokes.cvtStokesToAoLP-116.py37m.2.nbc │ │ ├── stokes.cvtStokesToAoLP-116.py37m.nbi │ │ ├── stokes.cvtStokesToDoLP-96.py37m.1.nbc │ │ ├── stokes.cvtStokesToDoLP-96.py37m.2.nbc │ │ ├── stokes.cvtStokesToDoLP-96.py37m.nbi │ │ ├── stokes.cvtStokesToIntensity-135.py37m.1.nbc │ │ ├── stokes.cvtStokesToIntensity-135.py37m.2.nbc │ │ ├── stokes.cvtStokesToIntensity-135.py37m.3.nbc │ │ ├── stokes.cvtStokesToIntensity-135.py37m.nbi │ │ ├── util.cpython-310.pyc │ │ ├── util.cpython-37.pyc │ │ ├── util.cpython-38.pyc │ │ └── util.cpython-39.pyc │ ├── demosaicing.py │ ├── mueller.py │ ├── stokes.py │ └── util.py ├── polprocess │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── test.cpython-37.pyc │ │ ├── tools.cpython-37.pyc │ │ └── tools.cpython-38.pyc │ ├── camera_npy_parser.py │ ├── camera_txt_parser.py │ ├── lucid_isp.py │ ├── mitsuba_isp.py │ ├── pandora_isp.py │ └── tools.py └── tools │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── vis_camera.cpython-37.pyc │ └── vis_camera.cpython-38.pyc │ ├── render_view.py │ ├── vis_camera.py │ ├── vis_ray.py │ └── vis_surface_and_cam.py ├── docs ├── normal_splatting.png └── pipeline.png ├── models ├── PolAnalyser.py ├── __init__.py ├── base.py ├── cameras.py ├── frameworks │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── holoNeuS.cpython-37.pyc │ │ ├── neus.cpython-37.pyc │ │ ├── pneus.cpython-37.pyc │ │ ├── pneus.cpython-38.pyc │ │ ├── pnr.cpython-38.pyc │ │ ├── pvolsdf.cpython-37.pyc │ │ ├── pvolsdf_gray.cpython-37.pyc │ │ ├── pvolsdf_mono.cpython-37.pyc │ │ ├── pvolsdf_sRGB.cpython-37.pyc │ │ ├── sslpneus.cpython-37.pyc │ │ ├── sslpvolsdf.cpython-37.pyc │ │ ├── unisurf.cpython-37.pyc │ │ └── volsdf.cpython-37.pyc │ └── pnr.py ├── loss.py ├── math_utils.py ├── ray_casting.py └── ray_sampler.py ├── render_view.sh ├── requirements.txt ├── sdf2mesh.py ├── sdf2msh_volsdf.py ├── surface.json ├── tools ├── 360cameraPath │ ├── camera_extrinsics.json │ └── camera_intrinsics.json ├── __init__.py ├── azi2aop.py ├── eval_3dprint.py ├── extract_surface.py ├── render_view.py ├── vis_camera.py ├── vis_ray.py └── vis_surface_and_cam.py ├── train.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── checkpoints.cpython-37.pyc │ ├── checkpoints.cpython-38.pyc │ ├── dist_util.cpython-310.pyc │ ├── dist_util.cpython-37.pyc │ ├── dist_util.cpython-38.pyc │ ├── dist_util.cpython-39.pyc │ ├── general.cpython-37.pyc │ ├── general.cpython-38.pyc │ ├── io_util.cpython-310.pyc │ 
├── io_util.cpython-37.pyc │ ├── io_util.cpython-38.pyc │ ├── log_utils.cpython-38.pyc │ ├── logger.cpython-310.pyc │ ├── logger.cpython-37.pyc │ ├── logger.cpython-38.pyc │ ├── mesh_util.cpython-37.pyc │ ├── mesh_util.cpython-38.pyc │ ├── plots.cpython-37.pyc │ ├── plots.cpython-38.pyc │ ├── print_fn.cpython-310.pyc │ ├── print_fn.cpython-37.pyc │ ├── print_fn.cpython-38.pyc │ ├── print_fn.cpython-39.pyc │ ├── rend_util.cpython-37.pyc │ ├── rend_util.cpython-38.pyc │ ├── train_util.cpython-37.pyc │ └── train_util.cpython-38.pyc ├── checkpoints.py ├── dist_util.py ├── general.py ├── io_util.py ├── log_utils.py ├── logger.py ├── mesh_util.py ├── plots.py ├── print_fn.py ├── rend_util.py └── train_util.py ├── val.py └── vis_weights.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # GNeRP: Gaussian-guided Neural Reconstruction of Reflective Objects with Noisy Polarization Priors (ICLR 2024)
4 |
5 | ## [Project Page](https://yukiumi13.github.io/gnerp_page/) | [Paper](https://yukiumi13.github.io/gnerp_page/gnerp_camera_ready.pdf) | [Dataset](https://drive.google.com/drive/folders/19j1Px5hT74dpZwKRgX0pfycr35AsbKgj?usp=sharing) | [Citation](#citation)
6 |
7 | This is the official repo for the implementation of [GNeRP: Gaussian-guided Neural Reconstruction of Reflective Objects with Noisy Polarization Priors](https://iclr.cc/virtual/2024/poster/17774), Yang LI, Ruizheng WU, Jiyong LI, Ying-Cong Chen.
8 |
9 |
10 | ## 📣 I'm actively looking for Ph.D. positions. Please see my [CV](https://yukiumi13.github.io/liyang.pdf) if interested.
11 |
12 | ## Abstract
13 | * 🚀 We propose Gaussian Splatting of surface **normals**, dedicated to reflective objects.
14 | 
15 | * 🚀 It is built upon NeuS and supervised by **Polarization Priors**.
16 | 
17 |
18 | ## Environment Setup
19 |
20 | The repo is built upon PyTorch 1.13.1 with CUDA 11.6. Additional packages are listed in ```requirements.txt```.
21 |
22 | ## Training and Inference
23 |
24 | ### Coordinate System Conventions
25 |
26 | The coordinate system used in ray casting follows [OpenCV](https://docs.opencv.org/4.x/d9/d0c/group__calib3d.html). It is **NOTED** that the origin of the **POLARIZATION** image plane is at the **BOTTOM-RIGHT** corner of the image, while the origin of the radiance image plane is at the **TOP-LEFT** corner. This convention determines the formulation of the AoP loss in the code. For details, please refer to the PolRef Dataset.
27 |
28 | For custom datasets, these coordinate conventions should be followed strictly; they can be checked with the visualization tools ```tools/{vis_camera, vis_ray, vis_surface_and_cam}.py``` and ```dataio/PolData.py```.
29 |
30 | ### Data Formats
31 |
32 | The camera parameters are stored as JSON files in the following format:
33 |
34 | ```
35 | {
36 |     "w2c_mat_0": [
37 |         [R_3x3, t_3x1],
38 |         [0.0_{1x3}, 1.0]
39 |     ],
40 |     ...
41 |     "intrinsic": [
42 |         [fx,0.0,tx],
43 |         [0.0,fy,ty],
44 |         [0.0,0.0,1.0]
45 |     ]
46 | }
47 | ```
48 | The scripts converting other formats can be found in ```dataio/polprocess/{camera_npy_parser,camera_txt_parser}.py```.
49 |
50 | AoP and DoP images are stored as NumPy arrays.
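For a quick sanity check of a converted ```cameras.json```, each camera center can be recovered from its ```w2c_mat_*``` entry. A minimal sketch, assuming only the JSON layout above (the data path is illustrative):

```python
import json
import numpy as np

# Load the camera file and split it into per-view world-to-camera matrices
# plus the shared intrinsic matrix, following the layout documented above.
with open("data/ironman/cameras.json") as f:
    cameras = json.load(f)

K = np.array(cameras["intrinsic"], dtype=np.float32)   # (3, 3)
w2c = {k: np.array(v, dtype=np.float32)                # (4, 4) each
       for k, v in cameras.items() if k.startswith("w2c_mat_")}

for name, mat in w2c.items():
    c2w = np.linalg.inv(mat)
    print(name, "camera center:", c2w[:3, 3])  # position in world coordinates
```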
The pre-process scripts that convert a raw polarization capture into radiance and AoP/DoP images are provided in ```dataio/polprocess/{lucid_isp, mitsuba_isp, pandora_isp}.py```. Samples can be found in the PolRef Dataset.
51 | ### Train a model from scratch
52 |
53 | Run the training script with
54 |
55 | ```python -m train --config configs/base_config.yaml --base_config configs/pnr.yaml ```
56 |
57 | ### Extract Geometry
58 | Typically, meshes can be extracted from checkpoints using the following command:
59 |
60 | ```python sdf2mesh.py --resolution 512 --json surface.json ```
61 |
62 | Moreover, we provide ```sdf2mesh_volsdf.py```, which follows the mesh-extraction method of VolSDF and includes additional post-processing steps. It is much slower, but produces a smoother mesh.
63 | ## PolRef Dataset
64 | The dataset is split into real scenes and synthetic scenes. Camera parameters follow the convention above. The data is organized as follows (similar to [PANDORA](https://github.com/akshatdave/pandora)):
65 | ```
66 | +-- data
67 | |   +-- ironman
68 | |       +-- aop
69 | |       +-- dop
70 | |       +-- images
71 | |       +-- masks
72 | |       +-- masks_ignore
73 | |       +-- cameras.json
74 | |
75 | +-- ...
76 | ```
77 |
78 | ## TODO
79 |
80 | - [x] release training code.
81 | - [x] release PolRef Synthetic Dataset.
82 | - [x] release PolRef Real Dataset.
83 |
84 | ## Citation
85 |
86 |
87 | If you find our work useful in your research, please consider citing:
88 |
89 | ```
90 | @inproceedings{li24gnerp,
91 |   title={GNeRP: Gaussian-guided Neural Reconstruction of Reflective Objects with Noisy Polarization Priors},
92 |   author={Li, Yang and Wu, Ruizheng and Li, Jiyong and Chen, Ying-Cong},
93 |   booktitle={ICLR},
94 |   year={2024}
95 | }
96 | ```
97 |
98 |
99 | ## Acknowledgments
100 |
101 | Our code is partially based on the [neurecon](https://github.com/ventusff/neurecon) project, and some code snippets are borrowed from [Ref-NeuS](https://github.com/EnVision-Research/Ref-NeuS). Thanks to these great projects.
102 |
--------------------------------------------------------------------------------
/cameras.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/cameras.png
--------------------------------------------------------------------------------
/configs/base_config.yaml:
--------------------------------------------------------------------------------
1 | expname: loss_mean
2 | parent_config: /dataset/yokoli/neurecon/configs/pnr.yaml
3 | # expname: neus
4 |
5 | # device_ids: [0]     # single gpu ; run on specified GPU
6 | # device_ids: [1, 0]  # DP ; run on specified GPU
7 | device_ids: -1        # single GPU / DP / DDP; run on all available GPUs;
8 |
9 | data:
10 |   data_dir: /dataset/yokoli/data/pol/duck
11 |   # crop_quantile: [3,2]
12 |   opengl: True
13 |   scale_radius: 2.0
14 |   chromatic: sRGB
15 |
16 | model:
17 |   surface:
18 |     D: 8
19 |     W: 256
20 |     skips: [4]
21 |     radius_init: 0.5
22 |     embed_multires: 8
23 |
24 | training:
25 |   pol:
26 |     splat: True
27 |   loss:
28 |     # lambda config
29 |     w_splat: 0.0
30 |     w_eik: 0.1
31 |     w_aop: 0.1
32 |     w_mask: 0.1
33 |     w_rgb: 1.0
34 |     pol_rew: False
35 |     splat_rew: False
36 |     # clip config
37 |     dop_upper: -1
38 |     # train scheduler
39 |     pol_start_it: -1
40 |     splat_start_it: 50000
41 |     # mask config
42 |     aop_mask: True
43 |     # objective function config
44 |     opengl: True # NOTE: ensure the alignment with data.opengl!
45 |     normal_perspective: True
46 |     svd_sup: True
47 |
48 |   log_root_dir: "logs/pnr/duck"
49 |   num_iters: 200000
50 |   i_val: 10000
51 |   i_val_mesh: 10000
--------------------------------------------------------------------------------
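The ```parent_config``` key above points at ```configs/pnr.yaml``` below, so an experiment config only needs to override the fields it changes. A minimal sketch of the implied merge, assuming a simple recursive dict overlay (the actual loader may differ):

```python
import yaml

def merge(base: dict, child: dict) -> dict:
    """Recursively overlay `child` on top of `base` (child values win)."""
    out = dict(base)
    for k, v in child.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = merge(out[k], v)
        else:
            out[k] = v
    return out

with open("configs/pnr.yaml") as f:
    base = yaml.safe_load(f)
with open("configs/base_config.yaml") as f:
    child = yaml.safe_load(f)

cfg = merge(base, child)
print(cfg["training"]["loss"]["w_splat"])  # 0.0 -- the child's override wins
```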
/configs/pnr.yaml:
--------------------------------------------------------------------------------
1 | expname: ???
2 |
3 | device_ids: -1 # single GPU / DP / DDP; run on all available GPUs;
4 |
5 | data:
6 |   type: PolData
7 |   batch_size: 1 # MUST be one
8 |   data_dir: ???
9 |   downscale: 1 # downscale image for training
10 |   scale_radius: 2.0 # scale all of the dataset's cameras to lie within this radius
11 |   chromatic: sRGB
12 |   opengl: ???
13 |   pin_memory: True
14 |
15 |   N_rays: 512 # N_rays for training
16 |   val_rayschunk: 256 # N_rays for validation
17 |   val_downscale: 2 # downscale image for validation
18 |   opengl: true
19 |
20 | model:
21 |   framework: pnr
22 |   obj_bounding_radius: 1.0
23 |   variance_init: 0.05
24 |   # N_outside: 32 # number of outside NeRF++ points. If with_mask, MUST BE ZERO
25 |
26 |   # upsampling related
27 |   N_upsample_iters: 4 # config for upsampling using 'official_solution'
28 |
29 |   normal_splatting: True
30 |   normal_gaussian_estimate: True
31 |   gaussian_scale_factor: 1.0
32 |
33 |   surface:
34 |     D: 8
35 |     W: 256
36 |     skips: [4]
37 |     radius_init: 0.5
38 |     embed_multires: 8
39 |
40 |   radiance:
41 |     D: 4
42 |     W: 256
43 |     skips: []
44 |     embed_multires: -1
45 |     embed_multires_view: 4 # as in the NeuS official implementation
46 |
47 | training:
48 |   lr: 5.0e-4
49 |   speed_factor: 10.0 # NOTE: unexpectedly, this is very important. Setting it to 1.0 causes some of the DTU instances to fail to converge correctly.
50 |
51 |   # neus
52 |   with_mask: true # NeRF++ if false
53 |   loss:
54 |     w_splat: 0.1
55 |     w_eik: 0.1
56 |     w_aop: 0.1
57 |     w_mask: 0.1
58 |     w_rgb: 0.1
59 |     pol_start_it: -1
60 |     splat_start_it: 50000
61 |     aop_mask: True
62 |     pol_rew: True
63 |     dop_upper: -1
64 |     normal_perspective: True
65 |
66 |   log_root_dir: "???"
67 |
68 |   # lr decay
69 |   scheduler:
70 |     type: warmupcosine
71 |     warmup_steps: 5000 # unit: iteration steps
72 |
73 |   # num_epochs: 50000
74 |   num_iters: 100000
75 |
76 |   ckpt_file: null # will be read by python as None
77 |   ckpt_ignore_keys: [] # only change if you want to drop certain keys in the saved checkpoints.
78 |   ckpt_only_use_keys: null # only change if you want to only use certain keys in the saved checkpoints.
79 |
80 |   monitoring: tensorboard
81 |
82 |   i_save: 900 # unit: seconds
83 |   i_backup: 50000 # unit: iteration steps
84 |
85 |   i_val: 5000
86 |   i_val_mesh: 10000
--------------------------------------------------------------------------------
/dataio/__init__.py:
--------------------------------------------------------------------------------
1 | def get_data(args, return_val=False, val_downscale=4.0, **overwrite_cfgs):
2 |     dataset_type = args.data.get('type', 'DTU')
3 |     cfgs = {
4 |         'scale_radius': args.data.get('scale_radius', -1),
5 |         'downscale': args.data.downscale,
6 |         'data_dir': args.data.data_dir,
7 |         'train_cameras': False,
8 |         'chromatic': args.data.get('chromatic', None),
9 |         'opengl': args.data.get('opengl', False),
10 |         'crop_quantile': args.data.get('crop_quantile', None)
11 |     }
12 |
13 |     if dataset_type == 'DTU':
14 |         cfgs = {
15 |             'scale_radius': args.data.get('scale_radius', -1),
16 |             'downscale': args.data.downscale,
17 |             'data_dir': args.data.data_dir,
18 |             'train_cameras': False,
19 |             # 'chromatic': args.data.get(('chromatic', None))
20 |         }
21 |
22 |         from .DTU import SceneDataset
23 |         cfgs['cam_file'] = args.data.get('cam_file', None)
24 |     elif dataset_type == 'custom':
25 |         from .custom import SceneDataset
26 |     elif dataset_type == 'BlendedMVS':
27 |         from .BlendedMVS import SceneDataset
28 |     elif dataset_type == 'PolData':
29 |         from .PolData import SceneDataset
30 |     elif dataset_type == 'normalData':
31 |         from .normalData import SceneDataset
32 |     else:
33 |         raise NotImplementedError
34 |
35 |     cfgs.update(overwrite_cfgs)
36 |     dataset = SceneDataset(**cfgs)
37 |     if return_val:
38 |         cfgs['downscale'] = val_downscale
39 |         val_dataset = SceneDataset(**cfgs)
40 |         return dataset, val_dataset
41 |     else:
42 |         return dataset
--------------------------------------------------------------------------------
/dataio/polanalyser/__init__.py:
--------------------------------------------------------------------------------
1 | from .stokes import *
2 | from .mueller import *
3 | from .demosaicing import *
4 |
--------------------------------------------------------------------------------
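The bundled ```dataio/polanalyser``` package re-exports the Stokes, Mueller, and demosaicing helpers defined in the modules that follow. A minimal sketch of the intended mono-capture flow, using a synthetic mosaic purely for illustration:

```python
import numpy as np
import dataio.polanalyser as pa

# Synthetic stand-in for a RAW mosaic from a polarization camera
# (e.g. IMX250MZR); a real capture would be loaded with cv2.imread.
img_raw = np.random.randint(1, 256, (1024, 1224), dtype=np.uint8)

# Demosaic into the four polarizer-angle channels (0/45/90/135).
img_demosaiced = pa.demosaicing(img_raw, pa.COLOR_PolarMono)  # (H, W, 4)
intensities = [img_demosaiced[..., i] for i in range(4)]
angles = np.deg2rad([0, 45, 90, 135])

# Per-pixel linear Stokes vector, then the AoP/DoP maps the trainer consumes.
img_stokes = pa.calcLinearStokes(intensities, angles)  # (H, W, 3)
img_aolp = pa.cvtStokesToAoLP(img_stokes)              # values in [0, pi]
img_dolp = pa.cvtStokesToDoLP(img_stokes)              # values in [0, 1]
img_vis = pa.applyColorToAoLP(img_aolp, saturation=img_dolp)  # HSV-coded AoLP
```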
/dataio/polanalyser/demosaicing.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | COLOR_PolarRGB = "COLOR_PolarRGB"
5 | COLOR_PolarMono = "COLOR_PolarMono"
6 |
7 | def demosaicing(img_raw, code=COLOR_PolarMono):
8 |     """Polarization demosaicing
9 |
10 |     Parameters
11 |     ----------
12 |     img_raw : np.ndarray, (height, width)
13 |         RAW polarization image taken with a polarization camera
14 |         (e.g. IMX250MZR or IMX250MYR sensor)
15 |     code : str (optional)
16 |         COLOR_PolarMono or COLOR_PolarRGB
17 |
18 |     Returns
19 |     -------
20 |     img_polarization : np.ndarray
21 |         Demosaiced image, in 0-45-90-135 channel order.
22 |     """
23 |     if code == COLOR_PolarMono:
24 |         if img_raw.dtype == np.uint8 or img_raw.dtype == np.uint16:
25 |             return __demosaicing_mono_uint(img_raw)
26 |         else:
27 |             return __demosaicing_mono_float(img_raw)
28 |     elif code == COLOR_PolarRGB:
29 |         if img_raw.dtype == np.uint8 or img_raw.dtype == np.uint16:
30 |             return __demosaicing_color(img_raw)
31 |         else:
32 |             raise TypeError("dtype of `img_raw` must be np.uint8 or np.uint16")
33 |     else:
34 |         raise ValueError(f"`code` must be {COLOR_PolarMono} or {COLOR_PolarRGB}")
35 |
36 | def __demosaicing_mono_uint(img_mpfa):
37 |     """Polarization demosaicing for np.uint8 or np.uint16 type
38 |     """
39 |     img_debayer_bg = cv2.cvtColor(img_mpfa, cv2.COLOR_BayerBG2BGR)
40 |     img_debayer_gr = cv2.cvtColor(img_mpfa, cv2.COLOR_BayerGR2BGR)
41 |     img_0, _, img_90 = np.moveaxis(img_debayer_bg, -1, 0)
42 |     img_45, _, img_135 = np.moveaxis(img_debayer_gr, -1, 0)
43 |     img_polarization = np.array([img_0, img_45, img_90, img_135], dtype=img_mpfa.dtype)
44 |     img_polarization = np.moveaxis(img_polarization, 0, -1)
45 |     return img_polarization
46 |
47 | def __demosaicing_mono_float(img_mpfa):
48 |     """Polarization demosaicing for arbitrary type
49 |
50 |     cv2.cvtColor supports either uint8 or uint16 type.
51 |     Float-type Bayer images are demosaiced by this function.
52 |
53 |     Notes
54 |     -----
55 |     pros: float available
56 |     cons: slow
57 |     """
58 |     height, width = img_mpfa.shape[:2]
59 |     img_subsampled = np.zeros((height, width, 4), dtype=img_mpfa.dtype)
60 |
61 |     img_subsampled[0::2, 0::2, 0] = img_mpfa[0::2, 0::2]
62 |     img_subsampled[0::2, 1::2, 1] = img_mpfa[0::2, 1::2]
63 |     img_subsampled[1::2, 0::2, 2] = img_mpfa[1::2, 0::2]
64 |     img_subsampled[1::2, 1::2, 3] = img_mpfa[1::2, 1::2]
65 |
66 |     kernel = np.array([[1/4, 1/2, 1/4],
67 |                        [1/2, 1.0, 1/2],
68 |                        [1/4, 1/2, 1/4]])
69 |
70 |     img_polarization = cv2.filter2D(img_subsampled, -1, kernel)
71 |
72 |     return img_polarization[..., [3, 1, 0, 2]]
73 |
74 | def __demosaicing_color(img_cpfa):
75 |     """Color-Polarization demosaicing for np.uint8 or np.uint16 type
76 |     """
77 |     height, width = img_cpfa.shape[:2]
78 |
79 |     # 1. Color demosaicing process
80 |     img_mpfa_bgr = np.empty((height, width, 3), dtype=img_cpfa.dtype)
81 |     for j in range(2):
82 |         for i in range(2):
83 |             # (i, j)
84 |             # (0, 0) is 90, (0, 1) is 45
85 |             # (1, 0) is 135, (1, 1) is 0
86 |
87 |             # Down sampling ↓2
88 |             img_bayer_ij = img_cpfa[j::2, i::2]
89 |             # Color demosaicking
90 |             img_bgr_ij = cv2.cvtColor(img_bayer_ij, cv2.COLOR_BayerBG2BGR)
91 |             # Up sampling ↑2
92 |             img_mpfa_bgr[j::2, i::2] = img_bgr_ij
93 |
94 |     # 2. Polarization demosaicing process
95 |     img_bgr_polarization = np.empty((height, width, 3, 4), dtype=img_mpfa_bgr.dtype)
96 |     for i, img_mpfa in enumerate(cv2.split(img_mpfa_bgr)):
97 |         img_demosaiced = demosaicing(img_mpfa, COLOR_PolarMono)
98 |         img_bgr_polarization[..., i, :] = img_demosaiced
99 |
100 |     return img_bgr_polarization
101 |
--------------------------------------------------------------------------------
/dataio/polanalyser/mueller.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def calcMueller(images, radians_light, radians_camera):
4 |     """
5 |     Calculate the Mueller matrix from captured images and the
6 |     angles of the linear polarizer on the light side and the camera side.
7 |
8 |     Parameters
9 |     ----------
10 |     images : np.ndarray, (height, width, N)
11 |         Captured images
12 |     radians_light : np.ndarray, (N,)
13 |         polarizer angles on the light side
14 |     radians_camera : np.ndarray, (N,)
15 |         polarizer angles on the camera side
16 |     Returns
17 |     -------
18 |     img_mueller : np.ndarray, (height, width, 9)
19 |         Calculated Mueller matrix image
20 |     """
21 |     cos_light = np.cos(2*radians_light)
22 |     sin_light = np.sin(2*radians_light)
23 |     cos_camera = np.cos(2*radians_camera)
24 |     sin_camera = np.sin(2*radians_camera)
25 |     A = np.array([np.ones_like(radians_light), cos_light, sin_light, cos_camera, cos_camera*cos_light, cos_camera*sin_light, sin_camera, sin_camera*cos_light, sin_camera*sin_light]).T
26 |     A_pinv = np.linalg.inv(A.T @ A) @ A.T #(9, N)
27 |     img_mueller = np.tensordot(A_pinv, images, axes=(1,-1)) #(9, height, width)
28 |     img_mueller = np.moveaxis(img_mueller, 0, -1) # (height, width, 9)
29 |     return img_mueller
30 |
31 | def rotator(theta):
32 |     """Generate Mueller matrix of rotation
33 |
34 |     Parameters
35 |     ----------
36 |     theta : float
37 |         the angle of rotation
38 |
39 |     Returns
40 |     -------
41 |     mueller : np.ndarray
42 |         mueller matrix (4, 4)
43 |     """
44 |     ones = np.ones_like(theta)
45 |     zeros = np.zeros_like(theta)
46 |     sin2 = np.sin(2*theta)
47 |     cos2 = np.cos(2*theta)
48 |     mueller = np.array([[ones, zeros, zeros, zeros],
49 |                         [zeros, cos2, sin2, zeros],
50 |                         [zeros, -sin2, cos2, zeros],
51 |                         [zeros, zeros, zeros, ones]])
52 |     mueller = np.moveaxis(mueller, [0,1], [-2,-1])
53 |     return mueller
54 |
55 | def rotateMueller(mueller, theta):
56 |     """Rotate Mueller matrix
57 |
58 |     Parameters
59 |     ----------
60 |     theta : float
61 |         the angle of rotation
62 |
63 |     Returns
64 |     -------
65 |     mueller : np.ndarray
66 |         mueller matrix (4, 4)
67 |     """
68 |     return rotator(-theta) @ mueller @ rotator(theta)
69 |
70 | def polarizer(theta):
71 |     """Generate Mueller matrix of linear polarizer
72 |
73 |     Parameters
74 |     ----------
75 |     theta : float
76 |         the angle of the linear polarizer
77 |
78 |     Returns
79 |     -------
80 |     mueller : np.ndarray
81 |         mueller matrix (4, 4)
82 |     """
83 |     mueller = np.array([[0.5, 0.5, 0, 0],
84 |                         [0.5, 0.5, 0, 0],
85 |                         [  0,   0, 0, 0],
86 |                         [  0,   0, 0, 0]]) # (4, 4)
87 |     mueller = rotateMueller(mueller, theta)
88 |     return mueller
89 |
90 | def retarder(delta, theta):
91 |     """Generate Mueller matrix of linear retarder
92 |
93 |     Parameters
94 |     ----------
95 |     delta : float
96 |         the phase difference between the fast and slow axis
97 |     theta : float
98 |         the angle of the fast axis
99 |
100 |     Returns
101 |     -------
102 |     mueller : np.ndarray
103 |         mueller matrix (4, 4)
104 |     """
105 |     ones = np.ones_like(delta)
106 |     zeros = np.zeros_like(delta)
107 |     sin = np.sin(delta)
108 |     cos = np.cos(delta)
109 |     mueller = np.array([[ones, zeros, zeros, zeros],
110 |                         [zeros, ones, zeros, zeros],
111 |                         [zeros, zeros, cos, -sin],
112 |                         [zeros, zeros, sin, cos]])
113 |     mueller = np.moveaxis(mueller, [0,1], [-2,-1])
114 |
115 |     mueller = rotateMueller(mueller, theta)
116 |     return mueller
117 |
118 | def qwp(theta):
119 |     """Generate Mueller matrix of Quarter-Wave Plate (QWP)
120 |
121 |     Parameters
122 |     ----------
123 |     theta : float
124 |         the angle of the fast axis
125 |
126 |     Returns
127 |     -------
128 |     mueller : np.ndarray
129 |         mueller matrix (4, 4)
130 |     """
131 |     return retarder(np.pi/2, theta)
132 |
133 | def hwp(theta):
134 |     """Generate Mueller matrix of Half-Wave Plate (HWP)
135 |
136 |     Parameters
137 |     ----------
138 |     theta : float
139 |         the angle of the fast axis
140 |
141 |     Returns
142 |     -------
143 |     mueller : np.ndarray
144 |         mueller matrix (4, 4)
145 |     """
146 |     return retarder(np.pi, theta)
147 |
148 |
149 | def plotMueller(filename, img_mueller, vabsmax=None, dpi=300, cmap="RdBu", add_title=True):
150 |     """
151 |     Apply a color map to the Mueller matrix image and save the components side by side
152 |
153 |     Parameters
154 |     ----------
155 |     filename : str
156 |         File name to be written.
157 |     img_mueller : np.ndarray, (height, width, 9) or (height, width, 16)
158 |         Mueller matrix image.
159 |     vabsmax : float
160 |         Absolute maximum value for plot. If None, the absolute maximum value of 'img_mueller' will be applied.
161 |     dpi : float
162 |         The resolution in dots per inch.
163 |     cmap : str
164 |         Color map for plot.
165 |         https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html
166 |     add_title : bool
167 |         Whether to insert a title (e.g. m11, m12...) in the image.
168 |     """
169 |     import matplotlib.pyplot as plt
170 |     from mpl_toolkits.axes_grid1 import ImageGrid
171 |     try:
172 |         plt.rcParams["mpl_toolkits.legacy_colorbar"] = False
173 |     except KeyError:
174 |         pass
175 |
176 |     # Check the 'img_mueller' shape
177 |     height, width, channel = img_mueller.shape
178 |     if channel==9:
179 |         n = 3
180 |     elif channel==16:
181 |         n = 4
182 |     else:
183 |         raise ValueError(f"'img_mueller' shape should be (height, width, 9) or (height, width, 16): ({height}, {width}, {channel})")
184 |
185 |     def add_inner_title(ax, title, loc, size=None, **kwargs):
186 |         """
187 |         Insert the title inside image
188 |         """
189 |         from matplotlib.offsetbox import AnchoredText
190 |         from matplotlib.patheffects import withStroke
191 |         if size is None:
192 |             size = dict(size=plt.rcParams['legend.fontsize'])
193 |         at = AnchoredText(title, loc=loc, prop=size,
194 |                           pad=0., borderpad=0.5,
195 |                           frameon=False, **kwargs)
196 |         ax.add_artist(at)
197 |         at.txt._text.set_path_effects([withStroke(foreground="w", linewidth=3)])
198 |         return at
199 |
200 |     # Create figure
201 |     fig = plt.figure()
202 |
203 |     # Create image grid
204 |     grid = ImageGrid(fig, 111,
205 |                      nrows_ncols=(n,n),
206 |                      axes_pad=0.0,
207 |                      share_all=True,
208 |                      cbar_location="right",
209 |                      cbar_mode="single",
210 |                      cbar_size="3%",
211 |                      cbar_pad=0.10,
212 |                      )
213 |
214 |     # Set absolute maximum value
215 |     vabsmax = np.max(np.abs(img_mueller)) if (vabsmax is None) else vabsmax
216 |     vmax = vabsmax
217 |     vmin = -vabsmax
218 |
219 |     # Add data to image grid
220 |     for i, ax in enumerate(grid):
221 |         # Remove the ticks
222 |         ax.set_xticks([])
223 |         ax.set_yticks([])
224 |
225 |         # Add title
226 |         if add_title:
227 |             maintitle = "$m$$_{0}$$_{1}$".format(i//n+1, i%n+1) # m{}{}
228 |             t = add_inner_title(ax, maintitle, loc='lower right')
229 |
230 |         # Add image
231 |         im = ax.imshow(img_mueller[:,:,i],
232 |                        vmin=vmin,
233 |                        vmax=vmax,
234 |                        cmap=cmap,
235 |                        )
236 |
237 |     # Colorbar
238 |     cbar = ax.cax.colorbar(im, ticks=[vmin, 0, vmax])
239 |     cbar.solids.set_edgecolor("face")
240 |     ax.cax.toggle_label(True)
241 |
242 |     # Save figure
243 |     plt.savefig(filename, bbox_inches='tight', dpi=dpi)
244 |
--------------------------------------------------------------------------------
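A quick consistency check of the ```polarizer```/```rotator``` convention defined above: unpolarized light passed through two ideal linear polarizers at relative angle theta should obey Malus's law, I = (I0/2) * cos^2(theta). A minimal sketch (run from the repo root):

```python
import numpy as np
from dataio.polanalyser.mueller import polarizer

s_in = np.array([1.0, 0.0, 0.0, 0.0])   # unpolarized light, unit intensity
theta = np.deg2rad(30)

# Chain two linear polarizers: the first at 0, the second at theta.
s_out = polarizer(theta) @ polarizer(0.0) @ s_in

print(s_out[0])                  # ~= 0.375
print(0.5 * np.cos(theta) ** 2)  # Malus's law gives the same value
```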
/dataio/polanalyser/stokes.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from .mueller import polarizer
4 | from .util import njit_if_available
5 |
6 | def calcStokes(intensities, muellers):
7 |     """Calculate stokes vector from observed intensity and mueller matrix
8 |
9 |     Parameters
10 |     ----------
11 |     intensities : np.ndarray
12 |         Intensity of measurements (height, width, n)
13 |     muellers : np.ndarray
14 |         Mueller matrix (3, 3, n) or (4, 4, n)
15 |
16 |     Returns
17 |     -------
18 |     stokes : np.ndarray
19 |         Stokes vector (height, width, 3) or (height, width, 4)
20 |     """
21 |     if not isinstance(intensities, np.ndarray):
22 |         intensities = np.stack(intensities, axis=-1) # (height, width, n)
23 |
24 |     if not isinstance(muellers, np.ndarray):
25 |         muellers = np.stack(muellers, axis=-1) # (3, 3, n) or (4, 4, n)
26 |
27 |     if muellers.ndim == 1:
28 |         # 1D array case
29 |         thetas = muellers
30 |         return calcLinearStokes(intensities, thetas)
31 |
32 |     A = muellers[0].T # [m11, m12, m13] (n, 3) or [m11, m12, m13, m14] (n, 4)
33 |     A_pinv = np.linalg.pinv(A) # (3, n) or (4, n)
34 |     stokes = np.tensordot(A_pinv, intensities, axes=(1, -1)) # (3, height, width) or (4, height, width)
35 |     stokes = np.moveaxis(stokes, 0, -1) # (height, width, 3) or (height, width, 4)
36 |     return stokes
37 |
38 | def calcLinearStokes(intensities, thetas):
39 |     """Calculate only the linear-polarization Stokes vector from observed intensity and linear polarizer angles
40 |
41 |     Parameters
42 |     ----------
43 |     intensities : np.ndarray
44 |         Intensity of measurements (height, width, n)
45 |     thetas : np.ndarray
46 |         Linear polarizer angles (n, )
47 |
48 |     Returns
49 |     -------
50 |     stokes : np.ndarray
51 |         Stokes vector (height, width, 3)
52 |     """
53 |     muellers = [ polarizer(theta)[..., :3, :3] for theta in thetas ]
54 |     return calcStokes(intensities, muellers)
55 |
56 | @njit_if_available(parallel=True, cache=True)
57 | def cvtStokesToImax(img_stokes):
58 |     """
59 |     Convert stokes vector image to Imax image
60 |
61 |     Parameters
62 |     ----------
63 |     img_stokes : np.ndarray, (height, width, 3)
64 |         Stokes vector image
65 |
66 |     Returns
67 |     -------
68 |     img_Imax : np.ndarray, (height, width)
69 |         Imax image
70 |     """
71 |     S0 = img_stokes[..., 0]
72 |     S1 = img_stokes[..., 1]
73 |     S2 = img_stokes[..., 2]
74 |     return (S0+np.sqrt(S1**2+S2**2))*0.5
75 |
76 | @njit_if_available(parallel=True, cache=True)
77 | def cvtStokesToImin(img_stokes):
78 |     """
79 |     Convert stokes vector image to Imin image
80 |
81 |     Parameters
82 |     ----------
83 |     img_stokes : np.ndarray, (height, width, 3)
84 |         Stokes vector image
85 |
86 |     Returns
87 |     -------
88 |     img_Imin : np.ndarray, (height, width)
89 |         Imin image
90 |     """
91 |     S0 = img_stokes[..., 0]
92 |     S1 = img_stokes[..., 1]
93 |     S2 = img_stokes[..., 2]
94 |     return (S0-np.sqrt(S1**2+S2**2))*0.5
95 |
96 | @njit_if_available(parallel=True, cache=True)
97 | def cvtStokesToDoLP(img_stokes):
98 |     """
99 |     Convert stokes vector image to DoLP (Degree of Linear Polarization) image
100 |
101 |     Parameters
102 |     ----------
103 |     img_stokes : np.ndarray, (height, width, 3)
104 |         Stokes vector image
105 |
106 |     Returns
107 |     -------
108 |     img_DoLP : np.ndarray, (height, width)
109 |         DoLP image ∈ [0, 1]
110 |     """
111 |     S0 = img_stokes[..., 0]
112 |     S1 = img_stokes[..., 1]
113 |     S2 = img_stokes[..., 2]
114 |     return np.sqrt(S1**2+S2**2)/S0
115 |
116 | @njit_if_available(parallel=True, cache=True)
117 | def cvtStokesToAoLP(img_stokes):
118 |     """
119 |     Convert stokes vector image to AoLP (Angle of Linear Polarization) image
120 |
121 |     Parameters
122 |     ----------
123 |     img_stokes : np.ndarray, (height, width, 3)
124 |         Stokes vector image
125 |
126 |     Returns
127 |     -------
128 |     img_AoLP : np.ndarray, (height, width)
129 |         AoLP image ∈ [0, np.pi]
130 |     """
131 |     S1 = img_stokes[..., 1]
132 |     S2 = img_stokes[..., 2]
133 |     return np.mod(0.5*np.arctan2(S2, S1), np.pi)
134 |
135 | @njit_if_available(parallel=True, cache=True)
136 | def cvtStokesToIntensity(img_stokes):
137 |     """
138 |     Convert stokes vector image to intensity image
139 |
140 |     Parameters
141 |     ----------
142 |     img_stokes : np.ndarray, (height, width, 3)
143 |         Stokes vector image
144 |
145 |     Returns
146 |     -------
147 |     img_intensity : np.ndarray, (height, width)
148 |         Intensity image
149 |     """
150 |     S0 = img_stokes[..., 0]
151 |     return S0*0.5
152 |
153 | @njit_if_available(parallel=True, cache=True)
154 | def cvtStokesToDiffuse(img_stokes):
155 |     """
156 |     Convert stokes vector image to diffuse image
157 |
158 |     Parameters
159 |     ----------
160 |     img_stokes : np.ndarray, (height, width, 3)
161 |         Stokes vector image
162 |
163 |     Returns
164 |     -------
165 |     img_diffuse : np.ndarray, (height, width)
166 |         Diffuse image
167 |     """
168 |     Imin = cvtStokesToImin(img_stokes)
169 |     return 1.0*Imin
170 |
171 | @njit_if_available(parallel=True, cache=True)
172 | def cvtStokesToSpecular(img_stokes):
173 |     """
174 |     Convert stokes vector image to specular image
175 |
176 |     Parameters
177 |     ----------
178 |     img_stokes : np.ndarray, (height, width, 3)
179 |         Stokes vector image
180 |
181 |     Returns
182 |     -------
183 |     img_specular : np.ndarray, (height, width)
184 |         Specular image
185 |     """
186 |     S1 = img_stokes[..., 1]
187 |     S2 = img_stokes[..., 2]
188 |     return np.sqrt(S1**2+S2**2) # same as Imax-Imin
189 |
190 | @njit_if_available(parallel=True, cache=True)
191 | def cvtStokesToDoP(img_stokes):
192 |     """
193 |     Convert stokes vector image to DoP (Degree of Polarization) image
194 |
195 |     Parameters
196 |     ----------
197 |     img_stokes : np.ndarray, (height, width, 4)
198 |         Stokes vector image (S3 is required here)
199 |
200 |     Returns
201 |     -------
202 |     img_DoP : np.ndarray, (height, width)
203 |         DoP image ∈ [0, 1]
204 |     """
205 |     S0 = img_stokes[..., 0]
206 |     S1 = img_stokes[..., 1]
207 |     S2 = img_stokes[..., 2]
208 |     S3 = img_stokes[..., 3]
209 |     return np.sqrt(S1**2+S2**2+S3**2)/S0
210 |
211 | @njit_if_available(parallel=True, cache=True)
212 | def cvtStokesToEllipticityAngle(img_stokes):
213 |     """
214 |     Convert stokes vector image to ellipticity angle image
215 |
216 |     Parameters
217 |     ----------
218 |     img_stokes : np.ndarray, (height, width, 4)
219 |         Stokes vector image (S3 is required here)
220 |
221 |     Returns
222 |     -------
223 |     img_EllipticityAngle : np.ndarray, (height, width)
224 |         ellipticity angle image ∈ [-pi/4, pi/4]
225 |     """
226 |     S1 = img_stokes[..., 1]
227 |     S2 = img_stokes[..., 2]
228 |     S3 = img_stokes[..., 3]
229 |     return 0.5*np.arctan2(S3, np.sqrt(S1**2+S2**2))
230 |
231 | @njit_if_available(parallel=True, cache=True)
232 | def cvtStokesToDoCP(img_stokes):
233 |     """
234 |     Convert stokes vector image to DoCP (Degree of Circular Polarization) image
235 |
236 |     Parameters
237 |     ----------
238 |     img_stokes : np.ndarray, (height, width, 4)
239 |         Stokes vector image (S3 is required here)
240 |
241 |     Returns
242 |     -------
243 |     img_DoCP : np.ndarray, (height, width)
244 |         DoCP image ∈ [-1, 1]
245 |     """
246 |     S0 = img_stokes[..., 0]
247 |     S3 = img_stokes[..., 3]
248 |     return S3 / S0
249 |
250 | def applyColorToAoLP(img_AoLP, saturation=1.0, value=1.0):
251 |     """
252 |     Apply color map to AoLP image
253 |     The color map is based on HSV
254 |
255 |     Parameters
256 |     ----------
257 |     img_AoLP : np.ndarray, (height, width)
258 |         AoLP image. The range is from 0.0 to pi.
259 |
260 |     saturation : float or np.ndarray, (height, width)
261 |         Saturation part (optional).
262 |         If you pass the DoLP image (img_DoLP) here, the output is modulated by DoLP.
263 |
264 |     value : float or np.ndarray, (height, width)
265 |         Value part (optional).
266 |         If you pass the DoLP image (img_DoLP) here, the output is modulated by DoLP.
267 |     """
268 |     img_ones = np.ones_like(img_AoLP)
269 |
270 |     img_hue = (np.mod(img_AoLP, np.pi)/np.pi*179).astype(np.uint8) # 0~pi -> 0~179
271 |     img_saturation = np.clip(img_ones*saturation*255, 0, 255).astype(np.uint8)
272 |     img_value = np.clip(img_ones*value*255, 0, 255).astype(np.uint8)
273 |
274 |     img_hsv = cv2.merge([img_hue, img_saturation, img_value])
275 |     img_bgr = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR)
276 |     return img_bgr
277 |
--------------------------------------------------------------------------------
/dataio/polanalyser/util.py:
--------------------------------------------------------------------------------
1 | import importlib.util
2 | is_numba_available = importlib.util.find_spec("numba") is not None
3 | if is_numba_available:
4 |     from numba import njit
5 |
6 | def njit_if_available(*args, **keywords):
7 |     def _njit_if_available(func):
8 |         if is_numba_available:
9 |             return njit(*args, **keywords)(func)
10 |         else:
11 |             return func
12 |
13 |     return _njit_if_available
14 |
--------------------------------------------------------------------------------
/dataio/polprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/dataio/polprocess/__init__.py
--------------------------------------------------------------------------------
/dataio/polprocess/camera_npy_parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append('/dataset/yokoli/neurecon')
4 |
5 | import json
6 | import numpy as np
7 | import argparse
8 | from utils.rend_util import load_K_Rt_from_P # NOTE: load_K_Rt_from_P is used below but was never imported; utils.rend_util is its assumed location in this repo
9 | def npz2json(root_path):
10 |     # root_path = '/newdata/yokoli/neurecon/data/pol/o'
11 |     camera_path = os.path.join(root_path, 'cameras.npz')
12 |
13 |     save_path = os.path.join(root_path, 'cameras.json')
14 |
15 |     cameras = {}
16 |
17 |     # import IPython; IPython.embed(); exit()
18 |
19 |     # for i in range(0, len(lines), 14):
20 |     #     filename = lines[i+1][:-1]
21 |
22 |     #     cameras[filename] = dict(
23 |     #         focal_length = list(map(float,lines[i+3][:-1].split(' '))),
24 |     #         image_center = list(map(float,lines[i+4][:-1].split(' '))),
25 |     #         translation = list(map(float,lines[i+5][:-2].split(' '))),
26 |     #         camera_pos = list(map(float,lines[i+6][:-2].split(' '))),
27 |     #         quaternion = list(map(float,lines[i+8][:-2].split(' '))),
28 |     #         rotation = [list(map(float, lines[i+9][:-2].split(' '))),
29 |     #                     list(map(float, lines[i+10][:-2].split(' '))),
30 |     #                     list(map(float, lines[i+11][:-2].split(' ')))],
31 |     #     )
32 |
33 |
34 |     camera_dict = np.load(camera_path)
35 |     image_names = sorted([k[10:] for k in camera_dict.keys() if k.startswith('scale_mat_')])
36 |     N_images = len(image_names)
37 |     print(f'{N_images} Detected.')
38 |     scale_mats = [camera_dict[f'scale_mat_{image_name}'].astype(np.float32) for image_name in image_names]
39 |     world_mats = [camera_dict[f'world_mat_{image_name}'].astype(np.float32) for image_name in image_names]
40 |
41 |     for img_name, scale_mat, world_mat in zip(image_names, scale_mats, world_mats):
42 |         print(f'Processing {img_name}')
43 |         P = world_mat @ scale_mat
44 |         P = P[:3, :4]
45 |         intrinsics, pose = load_K_Rt_from_P(P)
46 |         translation = -pose[:3,:3].T @ pose[:3,3]
47 |         cameras[f'{img_name}.png'] = dict(
48 |             focal_length = intrinsics[[0,1],[0,1]].astype(np.float32).tolist(),
49 |             image_center = intrinsics[[0,1],[2,2]].astype(np.float32).tolist(),
50 |             translation = translation.astype(np.float32).tolist(),
51 |             rotation = pose[:3,:3].T.astype(np.float32).tolist(), # c2w -> w2c
52 |             camera_pos=pose[:3,3].astype(np.float32).tolist()
53 |         )
54 |
55 |     json_str = json.dumps(cameras, indent=4)
56 |     with open(save_path, 'w') as json_file:
57 |         json_file.write(json_str)
58 |
59 |
60 |
61 | if __name__ == '__main__':
62 |     parser = argparse.ArgumentParser()
63 |     parser.add_argument('--path', type=str, default='./')
64 |     opt = parser.parse_args()
65 |     npz2json(opt.path)
66 |
67 |
--------------------------------------------------------------------------------
/dataio/polprocess/camera_txt_parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | def txt2json(root_path):
6 |     camera_path = os.path.join(root_path, 'camera.txt')
7 |     save_path = os.path.join(root_path, 'camera_extrinsics.json')
8 |
9 |     f = open(camera_path, "r+")
10 |
11 |     lines = f.readlines()
12 |
13 |     cameras = {}
14 |
15 |     # import IPython; IPython.embed(); exit()
16 |
17 |     for i in range(0, len(lines), 14):
18 |         filename = lines[i+1][:-1]
19 |
20 |         cameras[filename] = dict(
21 |             focal_length = list(map(float,lines[i+3][:-1].split(' '))),
22 |             image_center = list(map(float,lines[i+4][:-1].split(' '))),
23 |             translation = list(map(float,lines[i+5][:-2].split(' '))),
24 |             camera_pos = list(map(float,lines[i+6][:-2].split(' '))),
25 |             quaternion = list(map(float,lines[i+8][:-2].split(' '))),
26 |             rotation = [list(map(float, lines[i+9][:-2].split(' '))),
27 |                         list(map(float, lines[i+10][:-2].split(' '))),
28 |                         list(map(float, lines[i+11][:-2].split(' ')))],
29 |         )
30 |
31 |     json_str = json.dumps(cameras, indent=4)
32 |
33 |     with open(save_path, 'w') as json_file:
34 |         json_file.write(json_str)
35 |
36 | if __name__ == "__main__":
37 |     parser = argparse.ArgumentParser()
38 |     parser.add_argument('--path',type=str, default = './')
39 |     args = parser.parse_args()
40 |     txt2json(args.path)
41 |
--------------------------------------------------------------------------------
/dataio/polprocess/lucid_isp.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from utils.io_util import load_rgb, glob_imgs
3 | import numpy as np
4 | import polanalyser_new as pa
5 | import cv2
6 | import os
7 | import argparse
8 | from polprocess.tools import preprocess_raw_gray, print_np, convert_u8, linear_rescale, view_dop
9 | import imageio, glob
10 |
11 | parser = argparse.ArgumentParser()
12 |
13 | parser.add_argument('--instance_dir', type=str,
14 |                     help='input scene directory')
15 | parser.add_argument('--img_type', type=str,
16 |                     default='bmp',
17 |                     help='input image file extension')
18 | # parser.add_argument('--pose', type=str,default=False, help='convert SmartMore pose format to Colmap format')
19 | args = parser.parse_args()
20 |
21 | # instance_dir = './data/test/'
22 | # img_type = 'bmp'
23 |
24 | def calStokesFromArray(i0, i45, i90, i135, factor_0 = 1.0, factor_45 = 1.0, factor_90 = 1.0, factor_135 = 1.0):
25 |     s0 = np.maximum((factor_0 * i0) + (factor_90 * i90), (factor_45 * i45) + (factor_135 * i135))
26 |     s1 = (factor_0 * i0) - (factor_90 * i90)
27 |     s2 = (factor_45 * i45) - (factor_135 * i135)
28 |     return np.stack([s0, s1, s2], -1)
29 |
30 | def lucid_isp(instance_dir, img_type='jpg', depth = 8):
31 |     polarizerDir = f'{instance_dir}/Polarizer'
32 |     imgPath = sorted(glob.glob(f'{polarizerDir}/*90deg.{img_type}'))
33 |     print(imgPath)
34 |     imgList=[]
35 |     for p in imgPath:
36 |         imgName = os.path.basename(p).split('.')[-2].split('_')[:-1]
37 |         imgName = '_'.join(imgName)
38 |         imgList.append(imgName)
39 |
40 |     os.makedirs('{0}/DoP_vis'.format(instance_dir), exist_ok=True)
41 |     os.makedirs('{0}/AoP_vis'.format(instance_dir), exist_ok=True)
42 |     os.makedirs('{0}/AoP'.format(instance_dir), exist_ok=True)
43 |     os.makedirs('{0}/DoP'.format(instance_dir), exist_ok=True)
44 |     os.makedirs('{0}/sRGB'.format(instance_dir), exist_ok=True)
45 |     os.makedirs('{0}/Radiance'.format(instance_dir), exist_ok=True)
46 |     os.makedirs('{0}/AoP_vis_sat'.format(instance_dir), exist_ok=True)
47 |     os.makedirs('{0}/Stokes'.format(instance_dir), exist_ok=True)
48 |
49 |     print(f'{len(imgList)} Images: \n', imgList)
50 |
51 |     for imgName in imgList:
52 |         img0 = cv2.imread(f'{polarizerDir}/{imgName}_0.{img_type}')[:,:,0]
53 |         img45 = cv2.imread(f'{polarizerDir}/{imgName}_45.{img_type}')[:,:,0]
54 |         img90 = cv2.imread(f'{polarizerDir}/{imgName}_90.{img_type}')[:,:,0]
55 |         img135 = cv2.imread(f'{polarizerDir}/{imgName}_135.{img_type}')[:,:,0]
56 |         (H, W) = img0.shape
57 |
58 |
59 |         # Pack the four per-angle captures into a 2x2 polarizer mosaic (PFA)
60 |         imgRaw = np.zeros((2*H, 2*W))
61 |         imgRaw[::2, ::2] = img90/255.
62 |         imgRaw[1::2, ::2] = img135/255.
63 |         imgRaw[::2, 1::2] = img45/255.
64 |         imgRaw[1::2, 1::2] = img0/255.
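        # The super-pixel layout written above is: row 0 = [90, 45], row 1 = [135, 0].
        # This is the same layout the bundled dataio/polanalyser demosaicing assumes
        # ((even, even) -> 90, (even, odd) -> 45, (odd, even) -> 135, (odd, odd) -> 0).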
65 |
66 |         img0, img45, img90, img135 = pa.demosaicing(imgRaw, pa.COLOR_PolarRGB)
67 |         img_list = [img0, img45, img90, img135]
68 |         img_list = np.stack(img_list,-1)
69 |         angles = np.deg2rad([0, 45, 90, 135])
70 |         # img_stokes = calStokesFromArray(img0, img45, img90, img135)
71 |         three_pinv = 1
72 |         if not three_pinv:
73 |             img_stokes = (pa.calcStokes(img_list, angles))
74 |         else:
75 |             # Select angle that has highest intensity and remove it
76 |             max_angle_ind = np.argmax(img_list.sum(axis=(0,1,2)))
77 |             img_pp_channel_rem = [img_list[:,:,:,a] for a in range(len(angles)) if a != max_angle_ind]
78 |             angles_rem = [angles[a] for a in range(len(angles)) if a!= max_angle_ind]
79 |             img_stokes = pa.calcStokes(img_pp_channel_rem, angles_rem)
80 |         # img_stokes = pa.calcStokes(img_list, angles).astype('f4')
81 |         s0, s1, s2 = (img_stokes[:,:,:,0]/2.0).astype('f4'), (img_stokes[:,:,:,1]/2.0).astype('f4'), (img_stokes[:,:,:,2]/2.0).astype('f4')
82 |         stokes_rgb = np.stack([s0,s1,s2],-1)
83 |         # Convert BGR to GRAY
84 |         s0 = cv2.cvtColor(s0, cv2.COLOR_BGR2GRAY)
85 |         s1 = cv2.cvtColor(s1, cv2.COLOR_BGR2GRAY)
86 |         s2 = cv2.cvtColor(s2, cv2.COLOR_BGR2GRAY)
87 |         stokes = np.stack([s0,s1,s2],-1)
88 |
89 |         # Intensity Calculation
90 |         intensity = pa.cvtStokesToIntensity(stokes_rgb)
91 |         print(f'Max Intensity: {intensity.max()}')
92 |         cv2.imwrite(f'{instance_dir}/Radiance/{imgName}.png', (255*intensity).astype('u1'))
93 |         cv2.imwrite(f'{instance_dir}/sRGB/{imgName}.png', convert_u8(255*intensity))
94 |
95 |         # AoP Calculation
96 |         AoP = pa.cvtStokesToAoLP(stokes)
97 |         np.save(f'{instance_dir}/AoP/{imgName}.npy', AoP)
98 |         DoP = pa.cvtStokesToDoLP(stokes)
99 |         idx = np.isnan(DoP)
100 |         DoP[idx]=0.0
101 |         DoP = np.clip(DoP,0.0, 1.0)
102 |         print(f'Max DoP: {DoP.max()}')
103 |         np.save(f'{instance_dir}/DoP/{imgName}.npy', DoP)
104 |
105 |         # Visualization of DOP and AOP
106 |         aop_img = pa.applyColorToAoLP(AoP)
107 |         aop_img_sat = pa.applyColorToAoLP(AoP, saturation = DoP)
108 |         cv2.imwrite(f'{instance_dir}/AoP_vis/{imgName}.png', aop_img)
109 |         cv2.imwrite(f'{instance_dir}/AoP_vis_sat/{imgName}.png', aop_img_sat)
110 |
111 |         view_dop(stokes[...,0], stokes[...,1], stokes[...,2],f'{instance_dir}/DoP_vis/{imgName}.png' )
112 |
113 | if __name__ == '__main__':
114 |     args = parser.parse_args()
115 |     # cal_stokes(args.instance_dir)
116 |     lucid_isp('pol/Hulk', img_type = 'png')
--------------------------------------------------------------------------------
/dataio/polprocess/mitsuba_isp.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/dataset/yokoli/neurecon')
3 | from utils.io_util import load_rgb, glob_imgs
4 | import numpy as np
5 | import dataio.polanalyser as pa
6 | import cv2
7 | import os
8 | import argparse
9 | from dataio.polprocess.tools import preprocess_raw_gray, print_np, convert_u8, linear_rescale, view_dop
10 | import imageio
11 |
12 | parser = argparse.ArgumentParser()
13 |
14 | parser.add_argument('--instance_dir', type=str,
15 |                     help='input scene directory')
16 | parser.add_argument('--img_type', type=str,
17 |                     default='bmp',
18 |                     help='input image file extension')
19 | # parser.add_argument('--pose', type=str,default=False, help='convert SmartMore pose format to Colmap format')
20 | args = parser.parse_args()
21 |
22 | # instance_dir = './data/test/'
23 | # img_type = 'bmp'
24 |
25 | def cal_stokes(instance_dir):
26 |     n_images = len(glob_imgs(f'{instance_dir}/stokes'))//3
27 |     print(n_images)
28 |     # os.makedirs('{0}/images'.format(instance_dir), exist_ok=True)
os.makedirs('{0}/images'.format(instance_dir), exist_ok=True) 29 | os.makedirs('{0}/stokes'.format(instance_dir), exist_ok=True) 30 | # os.makedirs('{0}/vis_stokes'.format(instance_dir), exist_ok=True) 31 | # os.makedirs('{0}/calibration'.format(instance_dir), exist_ok=True) 32 | os.makedirs('{0}/DoP_vis'.format(instance_dir), exist_ok=True) 33 | os.makedirs('{0}/AoP_vis'.format(instance_dir), exist_ok=True) 34 | os.makedirs('{0}/aop'.format(instance_dir), exist_ok=True) 35 | os.makedirs('{0}/dop'.format(instance_dir), exist_ok=True) 36 | os.makedirs('{0}/images'.format(instance_dir), exist_ok=True) 37 | os.makedirs('{0}/AoP_vis_sat'.format(instance_dir), exist_ok=True) 38 | for idx in range(n_images): 39 | print(idx) 40 | # exit() 41 | # img_raw = cv2.imread(file_name, 0) 42 | s0 = cv2.imread(f'{instance_dir}/stokes/{idx}_s0.hdr', flags = cv2.IMREAD_UNCHANGED) 43 | s0p1 = cv2.imread((f'{instance_dir}/stokes/{idx}_s0p1.hdr'), flags = cv2.IMREAD_UNCHANGED) 44 | s0p2 = cv2.imread((f'{instance_dir}/stokes/{idx}_s0p2.hdr'), flags = cv2.IMREAD_UNCHANGED) 45 | s1 = s0p1 - s0 46 | s2 = s0p2 - s0 47 | stokes_rgb = np.stack([s0,s1,s2],-1) 48 | # Convert BGR to GRAY 49 | s0 = cv2.cvtColor(s0, cv2.COLOR_BGR2GRAY) 50 | s1 = cv2.cvtColor(s1, cv2.COLOR_BGR2GRAY) 51 | s2 = cv2.cvtColor(s2, cv2.COLOR_BGR2GRAY) 52 | stokes = np.stack([s0,s1,s2],-1) 53 | 54 | mask = cv2.imread(f'{instance_dir}/masks/{idx}.png') 55 | 56 | # Intensity Calculation 57 | intensity = stokes_rgb[...,0] 58 | print(f'Max Intensity: {intensity.max()}') 59 | # cv2.imwrite(f'{instance_dir}/Radiance/{idx}.png', (255*intensity).astype('u1')) 60 | # cv2.imwrite(f'{instance_dir}/images/{idx}.png', convert_u8(255*np.clip(intensity,0,1))) 61 | 62 | # AoP Calculation 63 | AoP = pa.cvtStokesToAoLP(stokes) 64 | AoP[~mask]=0.0 65 | np.save(f'{instance_dir}/aop/{idx}.npy', AoP) 66 | DoP = pa.cvtStokesToDoLP(stokes) 67 | DoP[~mask]=0.0 68 | DoP[np.isnan(DoP)]=0.0 69 | DoP = np.clip(DoP,0.0, 1.0) 70 | print(f'Max DoP: {DoP.max()}') 71 | np.save(f'{instance_dir}/dop/{idx}.npy', DoP) 72 | 73 | # Visualization of DOP and AOP 74 | aop_img = pa.applyColorToAoLP(AoP) 75 | aop_img_sat = pa.applyColorToAoLP(AoP, saturation = DoP) 76 | cv2.imwrite(f'{instance_dir}/AoP_vis/{idx}.png', aop_img) 77 | cv2.imwrite(f'{instance_dir}/AoP_vis_sat/{idx}.png', aop_img_sat) 78 | 79 | view_dop(stokes[...,0], stokes[...,1], stokes[...,2],f'{instance_dir}/DoP_vis/{idx}.png' ) 80 | 81 | if __name__ == '__main__': 82 | args = parser.parse_args() 83 | cal_stokes(args.instance_dir) -------------------------------------------------------------------------------- /dataio/polprocess/pandora_isp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/newdata/yokoli/neurecon') 3 | from utils.io_util import load_rgb, glob_imgs 4 | import numpy as np 5 | import dataio.polanalyser as pa 6 | import cv2 7 | import os 8 | import argparse 9 | from dataio.polprocess.tools import preprocess_raw_gray, print_np, convert_u8, linear_rescale, view_dop 10 | import imageio 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument('--instance_dir', type=str, 15 | help='input scene directory') 16 | parser.add_argument('--img_type', type=str, 17 | default='bmp', 18 | help='input scene directory') 19 | # parser.add_argument('--pose', type=str,default=False, help='convert SmartMore pose format to Colmap format') 20 | args = parser.parse_args() 21 | 22 | # instance_dir = './data/test/' 23 | # img_type = 'bmp' 24 | 25 | 
def cal_stokes(instance_dir, offset=3, img_type='bmp', depth = 8): 26 | image_dir = '{0}/sRGB'.format(instance_dir) 27 | image_paths = sorted(glob_imgs(image_dir)) 28 | # image_paths = [image_path for image_path in image_paths if image_path[image_path.rfind("_")+1:image_path.rfind(".")] == "0"] 29 | print(image_paths) 30 | # image_names = [image_path[image_path.rfind("\\") + 1: image_path.rfind("_")] for image_path in image_paths] 31 | image_names = [os.path.basename(image_path) for image_path in image_paths] 32 | # image_names = [image_path[image_path.find("\\") + 1: ] for image_path in image_paths] 33 | print(len(image_names)) 34 | # os.makedirs('{0}/images'.format(instance_dir), exist_ok=True) 35 | os.makedirs('{0}/stokes'.format(instance_dir), exist_ok=True) 36 | # os.makedirs('{0}/vis_stokes'.format(instance_dir), exist_ok=True) 37 | # os.makedirs('{0}/calibration'.format(instance_dir), exist_ok=True) 38 | os.makedirs('{0}/DoP_vis'.format(instance_dir), exist_ok=True) 39 | os.makedirs('{0}/AoP_vis'.format(instance_dir), exist_ok=True) 40 | os.makedirs('{0}/AoP'.format(instance_dir), exist_ok=True) 41 | os.makedirs('{0}/DoP'.format(instance_dir), exist_ok=True) 42 | os.makedirs('{0}/sRGB'.format(instance_dir), exist_ok=True) 43 | os.makedirs('{0}/Radiance'.format(instance_dir), exist_ok=True) 44 | os.makedirs('{0}/AoP_vis_sat'.format(instance_dir), exist_ok=True) 45 | for image_name in image_names: 46 | print(image_name) 47 | # exit() 48 | file_name = os.path.join(image_dir,image_name) 49 | # img_raw = cv2.imread(file_name, 0) 50 | image_name_save=image_name.split(".")[0] 51 | s0 = cv2.imread(f'{instance_dir}/images_stokes/{image_name_save}_s0.hdr', flags = cv2.IMREAD_UNCHANGED) 52 | s0p1 = cv2.imread((f'{instance_dir}/images_stokes/{image_name_save}_s0p1.hdr'), flags = cv2.IMREAD_UNCHANGED) 53 | s0p2 = cv2.imread((f'{instance_dir}/images_stokes/{image_name_save}_s0p2.hdr'), flags = cv2.IMREAD_UNCHANGED) 54 | s1 = s0p1 - s0 55 | s2 = s0p2 - s0 56 | stokes_rgb = np.stack([s0,s1,s2],-1) 57 | # Convert BGR to GRAY 58 | s0 = cv2.cvtColor(s0, cv2.COLOR_BGR2GRAY) 59 | s1 = cv2.cvtColor(s1, cv2.COLOR_BGR2GRAY) 60 | s2 = cv2.cvtColor(s2, cv2.COLOR_BGR2GRAY) 61 | stokes = np.stack([s0,s1,s2],-1) 62 | 63 | # Intensity Calculation 64 | intensity = pa.cvtStokesToIntensity(stokes_rgb) 65 | print(f'Max Intensity: {intensity.max()}') 66 | cv2.imwrite(f'{instance_dir}/Radiance/{image_name_save}.png', (255*intensity).astype('u1')) 67 | cv2.imwrite(f'{instance_dir}/sRGB/{image_name_save}.png', convert_u8(255*intensity)) 68 | 69 | # AoP Calculation 70 | AoP = pa.cvtStokesToAoLP(stokes) 71 | np.save(f'{instance_dir}/AoP/{image_name_save}.npy', AoP) 72 | DoP = pa.cvtStokesToDoLP(stokes) 73 | DoP = np.clip(DoP,0.0, 1.0) 74 | print(f'Max DoP: {DoP.max()}') 75 | np.save(f'{instance_dir}/DoP/{image_name_save}.npy', DoP) 76 | 77 | # Visualization of DOP and AOP 78 | aop_img = pa.applyColorToAoLP(AoP) 79 | aop_img_sat = pa.applyColorToAoLP(AoP, saturation = DoP) 80 | cv2.imwrite(f'{instance_dir}/AoP_vis/{image_name_save}.png', aop_img) 81 | cv2.imwrite(f'{instance_dir}/AoP_vis_sat/{image_name_save}.png', aop_img_sat) 82 | 83 | view_dop(stokes[...,0], stokes[...,1], stokes[...,2],f'{instance_dir}/DoP_vis/{image_name_save}.png' ) 84 | 85 | if __name__ == '__main__': 86 | args = parser.parse_args() 87 | cal_stokes(args.instance_dir) -------------------------------------------------------------------------------- /dataio/polprocess/tools.py: 
-------------------------------------------------------------------------------- 1 | from re import S 2 | import numpy as np 3 | import cv2 4 | from .. import polanalyser as pa 5 | from skimage.transform import rescale 6 | import glob 7 | import imageio 8 | 9 | def print_np (x,show_print=False, show_shape =False, show_dtype=False, show_range=True): 10 | if show_print: 11 | print(x) 12 | if show_shape: 13 | print(x.shape) 14 | if show_dtype: 15 | print(x.dtype) 16 | if show_range: 17 | print('[', x.min(), ',', x.max(), ']') 18 | 19 | def linear_rescale(x,min =0, max =255): 20 | y = ((x - 0)/(x.max() - 0))*(max - min) + min 21 | return y 22 | 23 | def convert_u8(image, gamma=1/2.2): 24 | '''Gamma Correction''' 25 | image = np.clip(image, 0, 255).astype(np.uint8) 26 | lut = (255.0 * (np.linspace(0, 1, 256) ** gamma)).astype(np.uint8) 27 | return lut[image] 28 | 29 | def preprocess_raw_gray(img_raw, scale=1., thres=1.,depth=8): 30 | # import time 31 | # t = time.time() 32 | out = {} 33 | img_raw = img_raw.astype('float32')/(2**(depth)-1) 34 | img_raw = scale*img_raw # H x W float 35 | img_raw = np.minimum(img_raw, thres) 36 | img_pp = pa.demosaicing(img_raw) 37 | 38 | # import imageio; imageio.imwrite('viz/img_s0.png',img_demosaiced[...,0].astype('float32')/4096) 39 | 40 | 41 | angles = np.deg2rad([0, 45, 90, 135]) 42 | # elapsed=time.time() - t 43 | # print(f'Preprocessing time {elapsed}') 44 | # t = time.time() 45 | three_pinv = 1 46 | if not three_pinv: 47 | img_stokes_channel = (pa.calcStokes(img_pp, angles)) 48 | else: 49 | # Select angle that has highest intensity and remove it 50 | max_angle_ind = np.argmax(img_pp.sum(axis=(0,1,2))) 51 | img_pp_channel_rem = np.stack([img_pp[:,:,a] 52 | for a in range(len(angles)) 53 | if a != max_angle_ind],-1) 54 | angles_rem = [angles[a] for a in range(len(angles)) 55 | if a!= max_angle_ind] 56 | img_stokes_channel = pa.calcStokes(img_pp_channel_rem, angles_rem) 57 | 58 | # elapsed=time.time() - t 59 | # print(f'Initial separation time {elapsed}') 60 | out['stokes'] = img_stokes_channel #HxWx3(Stokes) 61 | 62 | return out 63 | 64 | def demosaic_color_and_upsample(img_raw): 65 | #img_raw: CPFA H x W x 3 66 | # may should be H x W ? 
67 | #return: img_pfa_rgb: H x W x 3 x 4 68 | # may should be H//2 x W//2 x 3 x 4 and no upsampling 69 | height, width = img_raw.shape[:2] 70 | img_pfa_rgb = np.empty((height//2, width//2, 71 | 3,4), 72 | dtype=img_raw.dtype) 73 | for j in range(2): 74 | for i in range(2): 75 | # (i,j) 76 | # (0,0) is 90, (0,1) is 45 77 | # (1,0) is 135, (1,1) is 0 78 | 79 | # Downsampling by 2 80 | img_bayer_ij = img_raw[i::2, j::2] 81 | 82 | # Color correction 83 | # img_bayer_cc = np.clip(apply_cc_bayer(img_bayer_ij, 84 | # 'data/PMVIR_processed/ccmat.mat'), 85 | # 0,1) 86 | 87 | # Convert images to 16 bit 88 | img_bayer_16b = (img_bayer_ij*(2**16-1)).astype('uint16') 89 | # Color demosaicking 90 | img_rgb_ij_16b = cv2.cvtColor(img_bayer_16b, 91 | cv2.COLOR_BayerBG2RGB_EA) # Convert to 16bit and use edge aware demosaicking 92 | 93 | # Convert back to float 0, 1 94 | img_rgb_ij = img_rgb_ij_16b.astype('float32')/(2**16-1) 95 | 96 | # import imageio; imageio.imwrite('viz/pmvir_rgb/image_rgb_ij.exr',img_rgb_ij) 97 | # img_rgb_us = rescale(img_rgb_ij, 2, 98 | # anti_aliasing=False, 99 | # multichannel=True) 100 | img_rgb_us = img_rgb_ij 101 | # Save as stack 102 | img_pfa_rgb[:,:,:,2*i+j] = img_rgb_us # 90 45 135 0 h w 3 4 103 | 104 | # Upsampling 2 105 | img_pfa_rgb_cat = np.empty((height, width, 3),dtype = np.float32) 106 | for j in range(2): 107 | for i in range(2): 108 | img_pfa_rgb_cat[i::2,j::2] = img_pfa_rgb[:,:,:,2*i+j] 109 | rgb_dem_stack = [] 110 | for i in range(3): 111 | mono_dem = pa.demosaicing(img_pfa_rgb_cat[:,:,i]) 112 | rgb_dem_stack.append(mono_dem) 113 | rgb_dem = np.stack(rgb_dem_stack,-2) 114 | return rgb_dem 115 | 116 | def view_dop(s0, s1, s2, path): 117 | dop = np.sqrt(s1**2 + s2**2)/(s0+.0) 118 | cv2.imwrite(path, convert_u8(255*dop)) 119 | -------------------------------------------------------------------------------- /dataio/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/dataio/tools/__init__.py -------------------------------------------------------------------------------- /dataio/tools/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/dataio/tools/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /dataio/tools/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/dataio/tools/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /dataio/tools/__pycache__/vis_camera.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/dataio/tools/__pycache__/vis_camera.cpython-37.pyc -------------------------------------------------------------------------------- /dataio/tools/__pycache__/vis_camera.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/dataio/tools/__pycache__/vis_camera.cpython-38.pyc 
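Two short numpy sketches of the ideas used in polprocess/tools.py above; the helper names (stokes_lstsq, safe_dop) are hypothetical and not part of the repo. First, the three-angle Stokes estimation: pa.calcStokes fits the linear polarizer model I(theta) = 0.5 * (s0 + s1*cos(2*theta) + s2*sin(2*theta)) in the least-squares sense, and dropping the brightest angle in preprocess_raw_gray is a heuristic against saturated pixels; note that lucid_isp.py divides the result by 2, so polanalyser's exact scaling convention may differ from this sketch by that factor. Second, the DoP ratio in view_dop divides by s0, which yields NaN or inf on dark pixels (the callers clamp NaNs afterwards); a guarded variant avoids that.

import numpy as np

def stokes_lstsq(intensities, angles):
    """intensities: (..., K) stacked along the last axis; angles: (K,) in radians."""
    A = 0.5 * np.stack([np.ones_like(angles),
                        np.cos(2 * angles),
                        np.sin(2 * angles)], axis=-1)        # (K, 3) design matrix
    return np.einsum('ck,...k->...c', np.linalg.pinv(A), intensities)  # (..., 3) Stokes

def safe_dop(s0, s1, s2, eps=1e-8):
    """Degree of linear polarization with a guarded division."""
    return np.sqrt(s1 ** 2 + s2 ** 2) / np.maximum(s0, eps)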
-------------------------------------------------------------------------------- /dataio/tools/vis_camera.py: -------------------------------------------------------------------------------- 1 | ''' 2 | camera extrinsics visualization tools 3 | modified from https://github.com/opencv/opencv/blob/master/samples/python/camera_calibration_show_extrinsics.py 4 | ''' 5 | 6 | from utils.print_fn import log 7 | # Python 2/3 compatibility 8 | 9 | import numpy as np 10 | import cv2 as cv 11 | 12 | from numpy import linspace 13 | import matplotlib 14 | 15 | # matplotlib.use('TkAgg') 16 | 17 | def inverse_homogeneoux_matrix(M): 18 | R = M[0:3, 0:3] 19 | T = M[0:3, 3] 20 | M_inv = np.identity(4) 21 | M_inv[0:3, 0:3] = R.T 22 | M_inv[0:3, 3] = -(R.T).dot(T) 23 | 24 | return M_inv 25 | 26 | 27 | def transform_to_matplotlib_frame(cMo, X, inverse=False): 28 | M = np.identity(4) 29 | M[1, 1] = 0 30 | M[1, 2] = 1 31 | M[2, 1] = -1 32 | M[2, 2] = 0 33 | 34 | if inverse: 35 | return M.dot(inverse_homogeneoux_matrix(cMo).dot(X)) 36 | else: 37 | return M.dot(cMo.dot(X)) 38 | 39 | 40 | def create_camera_model(camera_matrix, width, height, scale_focal, draw_frame_axis=False): 41 | fx = camera_matrix[0, 0] 42 | fy = camera_matrix[1, 1] 43 | focal = 2 / (fx + fy) 44 | f_scale = scale_focal * focal 45 | 46 | # draw image plane 47 | X_img_plane = np.ones((4, 5)) 48 | X_img_plane[0:3, 0] = [-width, height, f_scale] 49 | X_img_plane[0:3, 1] = [width, height, f_scale] 50 | X_img_plane[0:3, 2] = [width, -height, f_scale] 51 | X_img_plane[0:3, 3] = [-width, -height, f_scale] 52 | X_img_plane[0:3, 4] = [-width, height, f_scale] 53 | 54 | # draw triangle above the image plane 55 | X_triangle = np.ones((4, 3)) 56 | X_triangle[0:3, 0] = [-width, -height, f_scale] 57 | X_triangle[0:3, 1] = [0, -2*height, f_scale] 58 | X_triangle[0:3, 2] = [width, -height, f_scale] 59 | 60 | # draw camera 61 | X_center1 = np.ones((4, 2)) 62 | X_center1[0:3, 0] = [0, 0, 0] 63 | X_center1[0:3, 1] = [-width, height, f_scale] 64 | 65 | X_center2 = np.ones((4, 2)) 66 | X_center2[0:3, 0] = [0, 0, 0] 67 | X_center2[0:3, 1] = [width, height, f_scale] 68 | 69 | X_center3 = np.ones((4, 2)) 70 | X_center3[0:3, 0] = [0, 0, 0] 71 | X_center3[0:3, 1] = [width, -height, f_scale] 72 | 73 | X_center4 = np.ones((4, 2)) 74 | X_center4[0:3, 0] = [0, 0, 0] 75 | X_center4[0:3, 1] = [-width, -height, f_scale] 76 | 77 | # draw camera frame axis 78 | X_frame1 = np.ones((4, 2)) 79 | X_frame1[0:3, 0] = [0, 0, 0] 80 | X_frame1[0:3, 1] = [f_scale/2, 0, 0] 81 | 82 | X_frame2 = np.ones((4, 2)) 83 | X_frame2[0:3, 0] = [0, 0, 0] 84 | X_frame2[0:3, 1] = [0, f_scale/2, 0] 85 | 86 | X_frame3 = np.ones((4, 2)) 87 | X_frame3[0:3, 0] = [0, 0, 0] 88 | X_frame3[0:3, 1] = [0, 0, f_scale/2] 89 | 90 | if draw_frame_axis: 91 | return [X_img_plane, X_triangle, X_center1, X_center2, X_center3, X_center4, X_frame1, X_frame2, X_frame3] 92 | else: 93 | return [X_img_plane, X_triangle, X_center1, X_center2, X_center3, X_center4] 94 | 95 | 96 | def create_board_model(extrinsics, board_width, board_height, square_size, draw_frame_axis=False): 97 | width = board_width*square_size 98 | height = board_height*square_size 99 | 100 | # draw calibration board 101 | X_board = np.ones((4, 5)) 102 | #X_board_cam = np.ones((extrinsics.shape[0],4,5)) 103 | X_board[0:3, 0] = [0, 0, 0] 104 | X_board[0:3, 1] = [width, 0, 0] 105 | X_board[0:3, 2] = [width, height, 0] 106 | X_board[0:3, 3] = [0, height, 0] 107 | X_board[0:3, 4] = [0, 0, 0] 108 | 109 | # draw board frame axis 110 | X_frame1 = np.ones((4, 2)) 111 | 
X_frame1[0:3, 0] = [0, 0, 0] 112 | X_frame1[0:3, 1] = [height/2, 0, 0] 113 | 114 | X_frame2 = np.ones((4, 2)) 115 | X_frame2[0:3, 0] = [0, 0, 0] 116 | X_frame2[0:3, 1] = [0, height/2, 0] 117 | 118 | X_frame3 = np.ones((4, 2)) 119 | X_frame3[0:3, 0] = [0, 0, 0] 120 | X_frame3[0:3, 1] = [0, 0, height/2] 121 | 122 | if draw_frame_axis: 123 | return [X_board, X_frame1, X_frame2, X_frame3] 124 | else: 125 | return [X_board] 126 | 127 | 128 | def draw_camera(ax, camera_matrix, cam_width, cam_height, scale_focal, 129 | extrinsics, 130 | patternCentric=True, 131 | annotation=True): 132 | from matplotlib import cm 133 | 134 | min_values = np.zeros((3, 1)) 135 | min_values = np.inf 136 | max_values = np.zeros((3, 1)) 137 | max_values = -np.inf 138 | 139 | X_moving = create_camera_model( 140 | camera_matrix, cam_width, cam_height, scale_focal) 141 | 142 | cm_subsection = linspace(0.0, 1.0, extrinsics.shape[0]) 143 | colors = [cm.jet(x) for x in cm_subsection] 144 | 145 | for idx in range(extrinsics.shape[0]): 146 | # R, _ = cv.Rodrigues(extrinsics[idx,0:3]) 147 | # cMo = np.eye(4,4) 148 | # cMo[0:3,0:3] = R 149 | # cMo[0:3,3] = extrinsics[idx,3:6] 150 | cMo = extrinsics[idx] 151 | for i in range(len(X_moving)): 152 | X = np.zeros(X_moving[i].shape) 153 | for j in range(X_moving[i].shape[1]): 154 | X[0:4, j] = transform_to_matplotlib_frame( 155 | cMo, X_moving[i][0:4, j], patternCentric) 156 | ax.plot3D(X[0, :], X[1, :], X[2, :], color=colors[idx]) 157 | min_values = np.minimum(min_values, X[0:3, :].min(1)) 158 | max_values = np.maximum(max_values, X[0:3, :].max(1)) 159 | # modified: add an annotation of number 160 | if annotation: 161 | X = transform_to_matplotlib_frame( 162 | cMo, X_moving[0][0:4, 0], patternCentric) 163 | ax.text(X[0], X[1], X[2], "{}".format(idx), color=colors[idx]) 164 | 165 | return min_values, max_values 166 | 167 | 168 | def visualize(camera_matrix, extrinsics): 169 | 170 | ######################## plot params ######################## 171 | cam_width = 0.064/2 # Width/2 of the displayed camera. 172 | cam_height = 0.048/2 # Height/2 of the displayed camera. 173 | scale_focal = 40 # Value to scale the focal length. 
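# --- Aside (hypothetical sanity check, not part of the original file) ---
# inverse_homogeneoux_matrix above relies on the rigid-transform identity
# [R | t]^-1 = [R^T | -R^T t], which holds whenever R is orthogonal; it must
# therefore agree with a generic 4x4 inverse. A quick standalone check,
# kept commented out so it does not run inside visualize():
#
#   M = np.eye(4)
#   M[:3, :3], _ = np.linalg.qr(np.random.randn(3, 3))  # random orthogonal R
#   M[:3, 3] = np.random.randn(3)
#   assert np.allclose(inverse_homogeneoux_matrix(M), np.linalg.inv(M))
# --- End aside ---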
174 | 175 | ######################## original code ######################## 176 | import matplotlib.pyplot as plt 177 | from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-variable 178 | 179 | fig = plt.figure(figsize=(39.4,21.6)) 180 | ax = fig.add_subplot(projection='3d') 181 | # ax.set_aspect("equal") 182 | ax.set_aspect("auto") 183 | 184 | min_values, max_values = draw_camera(ax, camera_matrix, cam_width, cam_height, 185 | scale_focal, extrinsics, True) 186 | 187 | X_min = min_values[0] 188 | X_max = max_values[0] 189 | Y_min = min_values[1] 190 | Y_max = max_values[1] 191 | Z_min = min_values[2] 192 | Z_max = max_values[2] 193 | max_range = np.array([X_max-X_min, Y_max-Y_min, Z_max-Z_min]).max() / 2.0 194 | 195 | mid_x = (X_max+X_min) * 0.5 196 | mid_y = (Y_max+Y_min) * 0.5 197 | mid_z = (Z_max+Z_min) * 0.5 198 | ax.set_xlim(mid_x - max_range, mid_x + max_range) 199 | ax.set_ylim(mid_y - max_range, mid_y + max_range) 200 | ax.set_zlim(mid_z - max_range, mid_z + max_range) 201 | 202 | ax.set_xlabel('x') 203 | ax.set_ylabel('z') 204 | ax.set_zlabel('-y') 205 | ax.set_title('Extrinsic Parameters Visualization') 206 | 207 | plt.savefig('./cameras.png') # save before show(): some backends leave an empty canvas after the window is closed 208 | plt.show() 209 | log.info('Done') 210 | 211 | 212 | if __name__ == '__main__': 213 | import argparse 214 | parser = argparse.ArgumentParser() 215 | parser.add_argument("--scan_id", type=int, default=40) 216 | args = parser.parse_args() 217 | 218 | log.info(__doc__) 219 | # NOTE: jianfei: 20210722 newly checked. The coordinate is correct. 220 | # note that the ticks on (-y) mean the opposite of the y coordinates. 221 | 222 | ######################## modified: example code ######################## 223 | from dataio.DTU import SceneDataset 224 | import torch 225 | train_dataset = SceneDataset( 226 | train_cameras=False, 227 | data_dir='./data/DTU/scan{}'.format(args.scan_id)) 228 | c2w = torch.stack(train_dataset.c2w_all).data.cpu().numpy() 229 | extrinsics = np.linalg.inv(c2w) # camera extrinsics are w2c matrices 230 | camera_matrix = next(iter(train_dataset))[1]['intrinsics'].data.cpu().numpy() 231 | 232 | 233 | # import pickle 234 | # data = pickle.load(open('./dev_test/london/london_siren_si20_cam.pt', 'rb')) 235 | # c2ws = data['c2w'] 236 | # extrinsics = np.linalg.inv(c2ws) 237 | # camera_matrix = data['intr'] 238 | visualize(camera_matrix, extrinsics) 239 | cv.destroyAllWindows() 240 | -------------------------------------------------------------------------------- /dataio/tools/vis_ray.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from utils.rend_util import get_rays 5 | from dataio.DTU import SceneDataset 6 | 7 | def plot_rays(rays_o: np.ndarray, rays_d: np.ndarray, ax): 8 | # TODO: automatically reduce the number of rays 9 | XYZUVW = np.concatenate([rays_o, rays_d], axis=-1) 10 | X, Y, Z, U, V, W = np.transpose(XYZUVW) 11 | # X2 = X+U 12 | # Y2 = Y+V 13 | # Z2 = Z+W 14 | # x_max = max(np.max(X), np.max(X2)) 15 | # x_min = min(np.min(X), np.min(X2)) 16 | # y_max = max(np.max(Y), np.max(Y2)) 17 | # y_min = min(np.min(Y), np.min(Y2)) 18 | # z_max = max(np.max(Z), np.max(Z2)) 19 | # z_min = min(np.min(Z), np.min(Z2)) 20 | # fig = plt.figure() 21 | # ax = fig.add_subplot(111, projection='3d') 22 | ax.quiver(X, Y, Z, U, V, W) 23 | # ax.set_xlim(x_min, x_max) 24 | # ax.set_ylim(y_min, y_max) 25 | # ax.set_zlim(z_min, z_max) 26 | 27 | return ax 28 | 29 | dataset = SceneDataset(False, './data/DTU/scan40', downscale=32) 30
| 31 | fig = plt.figure() 32 | ax = fig.add_subplot(111, projection='3d') 33 | ax.set_xlabel('x') 34 | ax.set_ylabel('y') 35 | ax.set_zlabel('z') 36 | ax.set_xlim(-2, 2) 37 | ax.set_ylim(-2, 2) 38 | ax.set_zlim(-2, 2) 39 | H, W = (dataset.H, dataset.W) 40 | 41 | for i in range(dataset.n_images): 42 | _, model_input, _ = dataset[i] 43 | intrinsics = model_input["intrinsics"][None, ...] 44 | c2w = model_input['c2w'][None, ...] 45 | rays_o, rays_d, select_inds = get_rays(c2w, intrinsics, H, W, N_rays=1) 46 | rays_o = rays_o.data.squeeze(0).cpu().numpy() 47 | rays_d = rays_d.data.squeeze(0).cpu().numpy() 48 | ax = plot_rays(rays_o, rays_d, ax) 49 | 50 | plt.show() -------------------------------------------------------------------------------- /dataio/tools/vis_surface_and_cam.py: -------------------------------------------------------------------------------- 1 | from utils import io_util 2 | from dataio import get_data 3 | 4 | import skimage 5 | import skimage.measure 6 | import numpy as np 7 | import open3d as o3d 8 | 9 | 10 | def get_camera_frustum(img_size, K, W2C, frustum_length=0.5, color=[0., 1., 0.]): 11 | W, H = img_size 12 | hfov = np.rad2deg(np.arctan(W / 2. / K[0, 0]) * 2.) 13 | vfov = np.rad2deg(np.arctan(H / 2. / K[1, 1]) * 2.) 14 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.)) 15 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.)) 16 | 17 | # build view frustum for camera (I, 0) 18 | frustum_points = np.array([[0., 0., 0.], # frustum origin 19 | [-half_w, -half_h, frustum_length], # top-left image corner 20 | [half_w, -half_h, frustum_length], # top-right image corner 21 | [half_w, half_h, frustum_length], # bottom-right image corner 22 | [-half_w, half_h, frustum_length]]) # bottom-left image corner 23 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) 24 | frustum_colors = np.tile(np.array(color).reshape((1, 3)), (frustum_lines.shape[0], 1)) 25 | 26 | # frustum_colors = np.vstack((np.tile(np.array([[1., 0., 0.]]), (4, 1)), 27 | # np.tile(np.array([[0., 1., 0.]]), (4, 1)))) 28 | 29 | # transform view frustum from (I, 0) to (R, t) 30 | C2W = np.linalg.inv(W2C) 31 | frustum_points = np.dot(np.hstack((frustum_points, np.ones_like(frustum_points[:, 0:1]))), C2W.T) 32 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] 33 | 34 | return frustum_points, frustum_lines, frustum_colors 35 | 36 | def frustums2lineset(frustums): 37 | N = len(frustums) 38 | merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum 39 | merged_lines = np.zeros((N*8, 2)) # 8 lines per frustum 40 | merged_colors = np.zeros((N*8, 3)) # each line gets a color 41 | 42 | for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums): 43 | merged_points[i*5:(i+1)*5, :] = frustum_points 44 | merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5 45 | merged_colors[i*8:(i+1)*8, :] = frustum_colors 46 | 47 | lineset = o3d.geometry.LineSet() 48 | lineset.points = o3d.utility.Vector3dVector(merged_points) 49 | lineset.lines = o3d.utility.Vector2iVector(merged_lines) 50 | lineset.colors = o3d.utility.Vector3dVector(merged_colors) 51 | 52 | return lineset 53 | 54 | 55 | # ---------------------- 56 | # plot cameras alongside with mesh 57 | # modified from NeRF++. 
https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/extract_sfm.py 58 | def visualize_cameras(colored_camera_dicts, sphere_radius, camera_size=0.1, geometry_file=None, geometry_type='mesh', backface=False): 59 | things_to_draw = [] 60 | 61 | if sphere_radius > 0: 62 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=sphere_radius, resolution=10) 63 | sphere = o3d.geometry.LineSet.create_from_triangle_mesh(sphere) 64 | sphere.paint_uniform_color((1, 0, 0)) 65 | things_to_draw.append(sphere) 66 | 67 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0., 0., 0.]) 68 | things_to_draw.append(coord_frame) 69 | 70 | idx = 0 71 | for camera_dict in colored_camera_dicts: 72 | idx += 1 73 | 74 | K = np.array(camera_dict['K']).reshape((4, 4)) 75 | W2C = np.array(camera_dict['W2C']).reshape((4, 4)) 76 | C2W = np.linalg.inv(W2C) 77 | img_size = camera_dict['img_size'] 78 | color = camera_dict['color'] 79 | frustums = [get_camera_frustum(img_size, K, W2C, frustum_length=camera_size, color=color)] 80 | cameras = frustums2lineset(frustums) 81 | things_to_draw.append(cameras) 82 | 83 | if geometry_file is not None: 84 | if geometry_type == 'mesh': 85 | geometry = o3d.io.read_triangle_mesh(geometry_file) 86 | geometry.compute_vertex_normals() 87 | elif geometry_type == 'pointcloud': 88 | geometry = o3d.io.read_point_cloud(geometry_file) 89 | else: 90 | raise Exception('Unknown geometry_type: ', geometry_type) 91 | 92 | things_to_draw.append(geometry) 93 | if backface: 94 | o3d.visualization.RenderOption.mesh_show_back_face = True 95 | o3d.visualization.draw_geometries(things_to_draw) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = io_util.create_args_parser() 100 | parser.add_argument("--scan_id", type=int, default=40) 101 | parser.add_argument("--mesh_file", type=str, default=None) 102 | parser.add_argument("--sphere_radius", type=float, default=3.0) 103 | parser.add_argument("--backface",action='store_true', help='render show back face') 104 | args = parser.parse_args() 105 | 106 | # load camera 107 | args, unknown = parser.parse_known_args() 108 | config = io_util.load_config(args, unknown) 109 | dataset = get_data(config) 110 | 111 | #------------- 112 | colored_camera_dicts = [] 113 | for i in range(len(dataset)): 114 | (_, model_input, ground_truth) = dataset[i] 115 | c2w = model_input['c2w'].data.cpu().numpy() 116 | intrinsics = model_input["intrinsics"].data.cpu().numpy() 117 | 118 | cam_dict = {} 119 | cam_dict['img_size'] = (dataset.W, dataset.H) 120 | cam_dict['W2C'] = np.linalg.inv(c2w) 121 | cam_dict['K'] = intrinsics 122 | # cam_dict['color'] = [0, 1, 1] 123 | cam_dict['color'] = [1, 0, 0] 124 | 125 | # if i == 0: 126 | # cam_dict['color'] = [1, 0, 0] 127 | 128 | # if i == 1: 129 | # cam_dict['color'] = [0, 1, 0] 130 | 131 | # if i == 28: 132 | # cam_dict['color'] = [1, 0, 0] 133 | 134 | colored_camera_dicts.append(cam_dict) 135 | 136 | visualize_cameras(colored_camera_dicts, args.sphere_radius, geometry_file=args.mesh_file, backface=args.backface) 137 | 138 | 139 | -------------------------------------------------------------------------------- /docs/normal_splatting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/docs/normal_splatting.png -------------------------------------------------------------------------------- /docs/pipeline.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/docs/pipeline.png -------------------------------------------------------------------------------- /models/PolAnalyser.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def applyColorToAoLP(img_AoLP, saturation=1.0, value=1.0): 7 | """ 8 | Apply a color map to an AoLP image. 9 | The color map is based on HSV. 10 | Parameters 11 | ---------- 12 | img_AoLP : np.ndarray, (height, width) 13 | AoLP image. The range is from 0.0 to pi. 14 | 15 | saturation : float or np.ndarray, (height, width) 16 | Saturation part (optional). 17 | If you pass a DoLP image (img_DoLP) as an argument, you can modulate it by DoLP. 18 | value : float or np.ndarray, (height, width) 19 | Value part (optional). 20 | If you pass a DoLP image (img_DoLP) as an argument, you can modulate it by DoLP. 21 | """ 22 | img_ones = np.ones_like(img_AoLP) 23 | 24 | img_hue = (np.mod(img_AoLP, np.pi)/np.pi*179).astype(np.uint8) # 0~pi -> 0~179 25 | img_saturation = np.clip(img_ones*saturation*255, 0, 255).astype(np.uint8) 26 | img_value = np.clip(img_ones*value*255, 0, 255).astype(np.uint8) 27 | 28 | img_hsv = cv2.merge([img_hue, img_saturation, img_value]) 29 | img_bgr = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR) 30 | img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) 31 | return img_rgb 32 | 33 | def batchApplyColorToAoLP(img_AoLP, saturation=1.0, value=1.0): 34 | """ 35 | Apply a color map to a batch of AoLP images. 36 | The color map is based on HSV. 37 | Parameters 38 | ---------- 39 | img_AoLP : np.ndarray, (B, height, width) 40 | AoLP image. The range is from 0.0 to pi. 41 | 42 | saturation : float or np.ndarray, (height, width) 43 | Saturation part (optional). 44 | If you pass a DoLP image (img_DoLP) as an argument, you can modulate it by DoLP. 45 | value : float or np.ndarray, (height, width) 46 | Value part (optional). 47 | If you pass a DoLP image (img_DoLP) as an argument, you can modulate it by DoLP. 48 | """ 49 | imgs_rgb = np.zeros((*img_AoLP.shape,3)) 50 | for i in range(img_AoLP.shape[0]): 51 | if isinstance(saturation, float): 52 | imgs_rgb[i,...] = applyColorToAoLP(img_AoLP[i,...],saturation=saturation) 53 | else: 54 | imgs_rgb[i,...]
= applyColorToAoLP(img_AoLP[i,...],saturation=saturation[i,...]) 55 | 56 | return imgs_rgb 57 | 58 | 59 | def normal_to_aop(normal_map_cam, opengl = False): 60 | '''From normals to predicted aop 61 | Args: 62 | [N_rays, 3] 63 | Return: 64 | [N_rays] 65 | ''' 66 | 67 | phi = torch.atan2(normal_map_cam[...,1], normal_map_cam[...,0]) # N_batch x N_rays (rad) 68 | 69 | phi_to_aop = np.pi/2 + phi if opengl else np.pi/2 - phi 70 | phi_to_aop = torch.remainder(phi_to_aop, np.pi) 71 | 72 | return phi_to_aop 73 | 74 | def world_normal_to_aop(pose, 75 | normal_map): 76 | '''From normals to predicted aop 77 | 78 | Return: 79 | [B, N_rays] 80 | ''' 81 | 82 | w2c = pose[:, :3,:3].transpose(1, 2) # R^T, B x 3 x 3 83 | # check_np(w2c) 84 | 85 | (N_batch, N_rays, _) = normal_map.shape 86 | 87 | # N_samples = normal_map.shape[0] // batch_size 88 | 89 | # print('get_AoP_loss', 'N_samples', N_samples) 90 | # print('get_AoP_loss', 'Batch_size', batch_size) 91 | 92 | # B x S --> S ONLY work if Batch Size is 1 93 | # AoP_gt = AoP_gt[0] 94 | # DoP_gt = DoP_gt[0] 95 | # mask = mask[0] 96 | 97 | # print('DoP_gt', DoP_gt.shape) 98 | # print('AoP_gt', AoP_gt.shape) 99 | 100 | 101 | # normal_map = normal_map.reshape([N_batch, 3, N_samples]) 102 | normal_map = normal_map.transpose(1, 2) # [B, 3, N_rays] 103 | normal_map_cam = torch.bmm(w2c, normal_map) # B x 3 x 3 @ B x 3 x N_rays = B x 3 x N_rays 104 | normal_map_cam = normal_map_cam.transpose(1, 2) # B x N_rays x 3 105 | phi = torch.atan2(normal_map_cam[...,1], normal_map_cam[...,0]) # N_batch x N_rays (rad) 106 | 107 | # MOD: PMIVR Loss Deprecated 108 | # eta = torch.stack([torch.abs(phi-AoP_gt-np.pi/2), torch.abs(phi-AoP_gt), torch.abs(phi-AoP_gt+np.pi/2), torch.abs(phi-AoP_gt+np.pi)], dim=1) 109 | # eta, _ = torch.min(eta, dim=1) 110 | 111 | phi_to_aop = np.pi/2 + phi 112 | # mod to [0, pi] 113 | phi_to_aop = torch.remainder(phi_to_aop, np.pi) 114 | 115 | return phi_to_aop 116 | 117 | def normal_to_dop(pose, normal_map): 118 | 119 | w2c = pose[:, :3,:3].transpose(1, 2) # R^T, B x 3 x 3 120 | # check_np(w2c) 121 | 122 | (N_batch, N_rays, _) = normal_map.shape 123 | 124 | # N_samples = normal_map.shape[0] // batch_size 125 | 126 | # print('get_AoP_loss', 'N_samples', N_samples) 127 | # print('get_AoP_loss', 'Batch_size', batch_size) 128 | 129 | # B x S --> S ONLY work if Batch Size is 1 130 | # AoP_gt = AoP_gt[0] 131 | # DoP_gt = DoP_gt[0] 132 | # mask = mask[0] 133 | 134 | # print('DoP_gt', DoP_gt.shape) 135 | # print('AoP_gt', AoP_gt.shape) 136 | 137 | 138 | # normal_map = normal_map.reshape([N_batch, 3, N_samples]) 139 | normal_map = normal_map.transpose(1, 2) # [B, 3, N_rays] 140 | normal_map_cam = torch.bmm(w2c, normal_map) # B x 3 x 3 @ B x 3 x N_rays = B x 3 x N_rays 141 | normal_map_cam = normal_map_cam.transpose(1, 2) # B x N_rays x 3 142 | image_plane_norm = normal_map_cam[...,:-1].norm(2,-1) # [B, N_rays] 143 | zenith_angle = torch.atan2(image_plane_norm, normal_map_cam[...,2]) 144 | # DEBUG 145 | # print('zenith_angle', zenith_angle.min(), zenith_angle.max()) 146 | n_ref = 1.5 * torch.ones_like(zenith_angle) # Refraction Index Default 1.5 147 | s2 = torch.sin(zenith_angle).pow(2) 148 | s4 = s2.pow(2) 149 | c2 = torch.cos(zenith_angle).pow(2) 150 | dop = torch.sqrt(s4*c2*(n_ref.pow(2)-s2))/((s4 +c2*(n_ref.pow(2) - s2))/2) 151 | 152 | return dop 153 | 154 | 155 | def check_np(x:np.ndarray): 156 | print(x.shape) 157 | print(f'[{x.min():02f},{x.max():02f}]') 158 | 159 | def visualize_aop(aop, dop=1.0): 160 | '''AoP -> HSV -> RGB 161 | 162 | Args: 163 | aop, 
dop (N,) 164 | ''' 165 | ones = torch.ones_like(aop) 166 | 167 | hue = (torch.remainder(aop, np.pi)/np.pi*179).type(torch.uint8).cpu().numpy() # 0~pi -> 0~179 168 | saturation = torch.clamp(ones*dop*255, 0, 255).type(torch.uint8).cpu().numpy() 169 | value = torch.clamp(ones*255, 0, 255).type(torch.uint8).cpu().numpy() 170 | 171 | img_hsv = cv2.merge([hue, saturation, value]) 172 | img_rgb = cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB) 173 | return torch.from_numpy(img_rgb)/255. -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/__init__.py -------------------------------------------------------------------------------- /models/cameras.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import plotly.offline as offline 4 | import plotly.graph_objs as go 5 | import torch.nn.functional as F 6 | 7 | def perspective_opengl(cx, cy, n, f): 8 | """Camera Space to Orthogonal Space 9 | 10 | Args: 11 | cx: (r-l)/2 12 | cy: (t-b)/2 13 | n: near >0 14 | f: far >0 15 | """ 16 | return torch.Tensor([[n/cx,0,0,0], 17 | [0,n/cy,0,0], 18 | [0,0,-(f+n)/(f-n), -2*n*f/(f-n)], 19 | [0,0,-1,0]]).type_as(n) 20 | 21 | def _batch_linear_transform(A, x): 22 | """batched matrix multiplication 23 | 24 | Args: 25 | A: [m,n] 26 | x: [N,n] 27 | """ 28 | return (A@(x[...,None])).squeeze(-1) 29 | 30 | def _cart_to_homo(x, w=1.0): 31 | return torch.cat([x, w*torch.ones_like(x[...,:1])],-1) 32 | 33 | def _homo_to_cart(x): 34 | return (x/x[...,3:])[...,:3] 35 | 36 | def perspective_to_orthogonal(pts, near, far): 37 | """Perspective to orthogonal 38 | 39 | Args: 40 | pts: N,3 41 | near: >0 42 | far: >0 43 | """ 44 | trans_matrix = torch.Tensor([[near,0,0,0],[0, near, 0, 0],[0,0,near+far, near*far],[0,0,-1,0]]).type_as(pts) 45 | return _homo_to_cart(_batch_linear_transform(trans_matrix, _cart_to_homo(pts))) 46 | 47 | def perspective_to_orthogonal_homo_coord(pts, near, far): 48 | """Perspective to orthogonal 49 | 50 | Args: 51 | pts: N,3 52 | near: >0 53 | far: >0 54 | """ 55 | trans_matrix = torch.Tensor([[near,0,0,0],[0, near, 0, 0],[0,0,near+far, near*far],[0,0,-1,0]]).type_as(pts) 56 | return (_batch_linear_transform(trans_matrix, _cart_to_homo(pts))) 57 | 58 | def orthogonal_to_perspective(pts, near, far): 59 | """Perspective to orthogonal 60 | 61 | Args: 62 | pts: N,3 63 | near: >0 64 | far: >0 65 | """ 66 | #NOTE: torch.inverse() doesn't support multi-threading, see https://github.com/pytorch/pytorch/issues/90613, use hard code 67 | trans_matrix = torch.Tensor([[1/near,0,0,0],[0, 1/near, 0, 0],[0,0,0, -1],[0,0,1/(near*far),(near+far)/(near*far)]]).type_as(pts) 68 | return _homo_to_cart(_batch_linear_transform(trans_matrix, _cart_to_homo(pts))) 69 | 70 | def orthogonal_to_perpective_homo_coord(pts, near, far): 71 | """Perspective to orthogonal 72 | 73 | Args: 74 | pts: N,4 75 | near: >0 76 | far: >0 77 | """ 78 | trans_matrix = torch.Tensor([[near,0,0,0],[0, near, 0, 0],[0,0,near+far, near*far],[0,0,-1,0]]).type_as(pts).inverse() 79 | return _homo_to_cart((_batch_linear_transform(trans_matrix, pts))) 80 | 81 | def perspective_to_ray_space(pts): 82 | """Used in https://www.cs.umd.edu/~zwicker/publications/EWAVolumeSplatting-VIS01.pdf 83 | 84 | Args: 85 | pts: [N, 3] 86 | 87 | Returns: 88 | pts 89 | """ 90 | z = 
pts.norm(p=2,dim=-1,keepdim=False) 91 | pts[:,0] /= pts[:,2] 92 | pts[:,1] /= pts[:,2] 93 | pts[:,2] = z 94 | return pts 95 | 96 | def rays_uni_sample(rays_o, rays_d, near = None, far = None, samples = 10): 97 | if far is not None: 98 | t_ = far * torch.ones_like(rays_o[...,0]) 99 | else: 100 | t_= 10. * torch.ones_like(rays_o[...,0]) 101 | if near is not None: 102 | rays_o = rays_o + near * rays_d 103 | k = torch.linspace(0,1,samples).type_as(rays_o) 104 | rays_t = (rays_o[None,...] + k[...,None,None] * t_[None,...,None] * rays_d[None,...]).reshape(-1,3) 105 | 106 | return rays_t 107 | 108 | def world_to_camera(c2w, pts): 109 | R_t = c2w[:3,:3].T 110 | t = - R_t @ c2w[:3,3] 111 | return _batch_linear_transform(R_t, pts) + t 112 | 113 | def world_to_camera_orient(c2w, pts): 114 | R_t = c2w[:3,:3].T 115 | return _batch_linear_transform(R_t, pts) 116 | 117 | def camera_projection( 118 | normal_map, 119 | pose): 120 | w2c = pose[:, :3,:3].transpose(1, 2) # R^T, B x 3 x 3 121 | (N_batch, N_rays,_) = normal_map.shape 122 | normal_map = normal_map.transpose(1, 2) # [B, 3, N_rays] 123 | normal_map_cam = torch.bmm(w2c, normal_map) # B x 3 x 3 @ B x 3 x N_rays = B x 3 x N_rays 124 | normal_map_cam = normal_map_cam.transpose(1, 2) # B x N_rays x 3 125 | return normal_map_cam 126 | 127 | def plot_pts(pts, filename='points'): 128 | _pts = pts.cpu().numpy() 129 | pts = go.Scatter3d(x=_pts[:,0], 130 | y=_pts[:,1], 131 | z=_pts[:,2], 132 | mode='markers', 133 | marker=dict(size=1)) 134 | fig = go.Figure(data=[pts]) 135 | offline.plot(fig, filename=f'{filename}.html', auto_open=False) 136 | 137 | def plot_pts_with_neighborhood(pts, pts_2, filename='points'): 138 | _pts = pts.cpu().numpy() 139 | _pts_2 = pts_2.cpu().numpy() 140 | pts = go.Scatter3d(x=_pts[:,0], 141 | y=_pts[:,1], 142 | z=_pts[:,2], 143 | mode='markers', 144 | marker=dict(size=1)) 145 | pts_2 = go.Scatter3d(x=_pts_2[:,0], 146 | y=_pts_2[:,1], 147 | z=_pts_2[:,2], 148 | mode='markers', 149 | marker=dict(size=1, color='yellow')) 150 | fig = go.Figure(data=[pts, pts_2]) 151 | offline.plot(fig, filename=f'{filename}.html', auto_open=False) 152 | 153 | def plot_normals(normals, pts, filename= 'normals', interval = None): 154 | _pts, _normals = pts.cpu().numpy(), F.normalize(normals).cpu().numpy() 155 | if interval is not None: 156 | _pts, _normals = _pts[::interval], _normals[::interval] 157 | xyz_range = _pts.max() - _pts.min() 158 | step_size = xyz_range / 50 159 | pts_end = _pts + step_size * _normals 160 | _pts = np.stack([_pts, pts_end], axis = -1) 161 | lines = [] 162 | colors = ((_normals + 1) / 2 * 255).astype('u1') 163 | for i in range(_pts.shape[0]): 164 | lines.append(go.Scatter3d(x = [_pts[i,0,0],_pts[i,0,1]], 165 | y = [_pts[i,1,0],_pts[i,1,1]], 166 | z = [_pts[i,2,0],_pts[i,2,1]], 167 | mode = 'lines', 168 | # marker = dict(showscale=False), 169 | line = dict(color='rgb({},{},{})'.format(*colors[i]), 170 | width=5)) 171 | ) 172 | dists = _pts[:,:,1] - _pts[:,:,0] 173 | cones = go.Cone(x=_pts[:,0,1], 174 | y=_pts[:,1,1], 175 | z=_pts[:,2,1], 176 | u = 0.3 * dists[:,0], 177 | v = 0.3 * dists[:,1], 178 | w = 0.3 * dists[:,2], 179 | sizeref=0.5, 180 | name='cones', 181 | showscale=False) 182 | 183 | fig = go.Figure(data=lines + [cones]) 184 | fig.update_layout(showlegend=False) 185 | offline.plot(fig, filename=f'{filename}.html', auto_open=False) 186 | 187 | def plot_vector_field(normals, pts, filename= 'vector_field', interval = None): 188 | _pts, _normals = pts.cpu().numpy(), F.normalize(normals).cpu().numpy() 189 | if interval is 
not None: 190 | _pts, _normals = _pts[::interval], _normals[::interval] 191 | xyz_range = _pts.max() - _pts.min() 192 | step_size = xyz_range / 50 193 | pts_end = _pts + step_size * _normals 194 | _pts = np.stack([_pts, pts_end], axis = -1) 195 | lines = [] 196 | colors = ((_normals + 1) / 2 * 255).astype('u1') 197 | for i in range(_pts.shape[0]): 198 | lines.append(go.Scatter3d(x = [_pts[i,0,0],_pts[i,0,1]], 199 | y = [_pts[i,1,0],_pts[i,1,1]], 200 | z = [_pts[i,2,0],_pts[i,2,1]], 201 | mode = 'lines', 202 | # marker = dict(showscale=False), 203 | line = dict(color='rgb({},{},{})'.format(*colors[i]), 204 | width=5)) 205 | ) 206 | dists = _pts[:,:,1] - _pts[:,:,0] 207 | cones = go.Cone(x=_pts[:,0,1], 208 | y=_pts[:,1,1], 209 | z=_pts[:,2,1], 210 | u = 0.3 * dists[:,0], 211 | v = 0.3 * dists[:,1], 212 | w = 0.3 * dists[:,2], 213 | sizeref=0.5, 214 | name='cones', 215 | showscale=False) 216 | points_plot = go.Scatter3d(x = _pts[:,0,0], 217 | y = _pts[:,1,0], 218 | z = _pts[:,2,0], 219 | mode = 'markers', 220 | marker = dict(color = _pts[:,2,0], size = 3)) 221 | fig = go.Figure(data=lines + [cones] + [points_plot]) 222 | fig.update_layout(showlegend=False) 223 | offline.plot(fig, filename=f'{filename}.html', auto_open=False) 224 | 225 | def create_step_vectors(pts): 226 | ones_vec = torch.ones([3,4]).type_as(pts) 227 | ones_vec[0,2:] = -1 228 | ones_vec[1,1::2] = -1 229 | ones_vec[2,:] = 0 230 | return [ones_vec[:,i] for i in range(4)] 231 | 232 | def create_step_vectors_2x(pts): 233 | ones_vec = torch.ones([3,8]).type_as(pts) 234 | ones_vec[[0,0],[2,3]] = -1 235 | ones_vec[[1,1],[1,3]] = -1 236 | ones_vec[:,4:] = 0.5*ones_vec[:,:4] 237 | ones_vec[2,:] = 0 238 | return [ones_vec[:,i] for i in range(8)] 239 | 240 | def create_step_vectors_4x(pts): 241 | ones_vec = torch.ones([3,16]).type_as(pts) 242 | ones_vec[[0,0],[2,3]] = -1 243 | ones_vec[[1,1],[1,3]] = -1 244 | ones_vec[0,6:] = 0 245 | ones_vec[0,5] = -1 246 | ones_vec[1,4:6] = 0 247 | ones_vec[1,7] = -1 248 | ones_vec[2,:] = 0 249 | return [ones_vec[:,i] for i in range(8)] 250 | 251 | def indexing_2d_samples(select_inds, H, W, scale_pixel = 1): 252 | _bound = H * W - 1 253 | _up = select_inds - scale_pixel * W 254 | _up = torch.where(_up > 0, _up, select_inds) 255 | _down = select_inds + scale_pixel * W 256 | _down = torch.where(_down < _bound, _down, select_inds) 257 | _right = select_inds + scale_pixel 258 | _right = torch.where(_right < _bound, _right, select_inds) 259 | _left = select_inds - scale_pixel 260 | _left = torch.where(_left > 0, _left, select_inds) 261 | return torch.stack([_up, _down, _left, _right], -1) # [1, N_rays, 4] 262 | 263 | def congruent_transform(sigma, V): 264 | V = V.expand_as(sigma) 265 | return V @ sigma @ V.transpose(-1,-2) -------------------------------------------------------------------------------- /models/frameworks/__init__.py: -------------------------------------------------------------------------------- 1 | def get_model(args): 2 | if args.model.framework == 'UNISURF': 3 | from .unisurf import get_model 4 | elif args.model.framework == 'NeuS': 5 | from .neus import get_model 6 | elif args.model.framework == 'VolSDF': 7 | from .volsdf import get_model 8 | elif args.model.framework == 'PVolSDF': 9 | from .pvolsdf_sRGB import get_model 10 | elif args.model.framework == 'PVolSDFMono': 11 | from .pvolsdf_mono import get_model 12 | elif args.model.framework == 'PNeuS': 13 | from .pneus import get_model 14 | elif args.model.framework == 'SSL-PNeuS': 15 | from .sslpneus import get_model 16 | elif 
args.model.framework == 'SSL-PVolSDF': 17 | from .sslpvolsdf import get_model 18 | elif args.model.framework == 'holoNeuS': 19 | from .holoNeuS import get_model 20 | elif args.model.framework == 'pnr': 21 | from .pnr import get_model 22 | else: 23 | raise NotImplementedError 24 | return get_model(args) -------------------------------------------------------------------------------- /models/frameworks/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/holoNeuS.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/holoNeuS.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/neus.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/neus.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/pneus.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/pneus.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/pneus.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/pneus.cpython-38.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/pnr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/pnr.cpython-38.pyc 
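For reference, a hedged usage sketch of the get_model dispatcher defined in models/frameworks/__init__.py above. The config path and attribute layout are assumptions based on configs/pnr.yaml, and the return value is simply whatever the selected framework module's own get_model produces (in codebases of this style, typically the model together with its render functions); this is not verified against train.py.

from utils.io_util import load_yaml
from models.frameworks import get_model

config = load_yaml('configs/pnr.yaml')   # assumed to set config.model.framework = 'pnr'
model = get_model(config)                # delegates to models/frameworks/pnr.py:get_model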
-------------------------------------------------------------------------------- /models/frameworks/__pycache__/pvolsdf.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/pvolsdf.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/pvolsdf_gray.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/pvolsdf_gray.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/pvolsdf_mono.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/pvolsdf_mono.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/pvolsdf_sRGB.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/pvolsdf_sRGB.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/sslpneus.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/sslpneus.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/sslpvolsdf.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/sslpvolsdf.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/unisurf.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/unisurf.cpython-37.pyc -------------------------------------------------------------------------------- /models/frameworks/__pycache__/volsdf.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/models/frameworks/__pycache__/volsdf.cpython-37.pyc -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | '''From Polarized VolSDF, define AoP Loss 2 | ''' 3 | import torch 4 | from torch import nn 5 | import utils.general as utils 6 | 7 | import numpy as np 8 | from collections import OrderedDict 9 | import torch.nn.functional as F 10 | from .math_utils import two_order_real_symmetric_matrix_svd 11 | 12 | def covariance_to_correlation(cov): 13 | """2d covariance -> correlation 14 | sigma(x_1, x_2)/(sigma(x_1)sigma(x_2)) 15 | Args: 16 | cov: _, 2, 
2 17 | """ 18 | assert cov.shape[-2] == cov.shape[-1] and cov.shape[-1] == 2 19 | print('cov', cov.abs().min()) 20 | std_x, std_y = (cov[...,0,0] + 1e-7).sqrt(), (cov[...,1,1]+ 1e-7).sqrt() 21 | # cov[...,0,0] /= std_x * std_x 22 | cov[...,0,0] = 1. 23 | cov[...,0,1] /= (std_x * std_y) 24 | cov[...,1,0] /= (std_y * std_x) 25 | print('divisor', (std_x * std_y).min()) 26 | # cov[...,1,1] /= std_y * std_y 27 | cov[...,1,1] = 1. 28 | return cov 29 | 30 | 31 | class polLoss(nn.Module): 32 | def __init__(self, args): 33 | super().__init__() 34 | self.w = args 35 | 36 | def forward(self, model_outputs, ground_truth, iteration): 37 | 38 | rgb_pred = model_outputs['rgb'] 39 | rgb_gt = ground_truth['rgb'] 40 | AoP_gt = ground_truth['AoP_map'] 41 | DoP_gt = ground_truth['DoP_map'] 42 | mask = ground_truth['mask'] 43 | grad_norm = model_outputs['grad_norm'] 44 | mask_ignore = ground_truth['mask_ignore'] 45 | normal_map_cam = model_outputs['normals_rayspace'] if self.w.normal_perspective else model_outputs['normals_ortho'] 46 | 47 | losses = OrderedDict() 48 | 49 | losses['loss_img'] = F.l1_loss(rgb_pred, rgb_gt, reduction='none') 50 | losses['loss_img'] = self.w.w_rgb * losses['loss_img'] 51 | 52 | if self.w.pol_rew and iteration > self.w.pol_start_it: 53 | dop_w = torch.clamp(DoP_gt, min = 0.0, max = self.w.dop_upper) if self.w.dop_upper > 0 else DoP_gt 54 | losses['loss_img'] = (1-dop_w[...,None]) * losses['loss_img'] 55 | if self.w.w_mask > 0: 56 | losses['loss_img'] = (losses['loss_img'] * mask[..., None].float()).sum() / (mask.sum() + 1e-10) 57 | losses['loss_mask'] = self.w.w_mask * F.binary_cross_entropy(model_outputs['mask_volume'], mask.float(), reduction='mean') 58 | else: 59 | losses['loss_img'] = losses['loss_img'].mean() 60 | 61 | if self.w.w_eik > 0: 62 | losses['loss_eikonal'] = self.w.w_eik * F.mse_loss(grad_norm, grad_norm.new_ones(grad_norm.shape), reduction='mean') 63 | 64 | if self.w.w_aop > 0 and iteration > self.w.pol_start_it: 65 | azi_angle = torch.atan2(normal_map_cam[...,1], normal_map_cam[...,0] + 1e-10) # N_batch x N_rays (rad) [-pi,pi] 66 | aop_pred = torch.remainder(np.pi/2 + azi_angle, np.pi) if self.w.opengl else torch.remainder(np.pi/2 - azi_angle, np.pi) 67 | eta = F.l1_loss(aop_pred, AoP_gt, reduction='none') 68 | _mask = mask * (~mask_ignore) if mask_ignore is not None else mask 69 | if self.w.pol_rew: 70 | AoP_loss = (DoP_gt * _mask.float() * eta).sum()/ (_mask.sum() + 1e-10) if self.w.aop_mask else (DoP_gt * eta).mean() 71 | else: 72 | AoP_loss = (_mask.float() * eta).sum()/ (_mask.sum() + 1e-10) if self.w.aop_mask else (DoP_gt * eta).mean() 73 | losses['loss_aop'] = self.w.w_aop * AoP_loss 74 | 75 | if self.w.w_splat > 0 and iteration > self.w.splat_start_it: 76 | _norm_scale = normal_map_cam.norm(dim = -1)[...,None] 77 | normals_aop_mean = _norm_scale * torch.stack([torch.sin(AoP_gt), -torch.cos(AoP_gt)], dim = -1) 78 | normals_aop_samples=_norm_scale[...,None] * torch.stack([torch.sin(ground_truth['aop_samples']), -torch.cos(ground_truth['aop_samples'])], dim = -1) 79 | normals_dop_samples= ground_truth['dop_samples'] 80 | normals_aop_samples = (normals_aop_samples - normals_aop_mean[...,None,:]) 81 | normals_aop_cov = (normals_aop_samples.transpose(-1,-2) @ normals_aop_samples)/3 82 | normals_image_cov = model_outputs['normals_image_cov'] 83 | if self.w.get('svd_sup', False): 84 | #----- test: SVD Sup. 
-------- 85 | # NOTE: torch.linalg.svd can produce NaNs in the backward pass 86 | # img_svd_vec = torch.linalg.svd(normals_image_cov , driver = "gesvd" )[0] 87 | # img_svd_val = torch.linalg.svd(normals_image_cov, driver = "gesvd" )[1] 88 | # aop_svd_vec = torch.linalg.svd(normals_aop_cov, driver = "gesvd")[0] 89 | img_svd_val, img_svd_vec = two_order_real_symmetric_matrix_svd(normals_image_cov) 90 | aop_svd_val, aop_svd_vec = two_order_real_symmetric_matrix_svd(normals_aop_cov) 91 | anisotropic_img = (img_svd_val[...,1] / (img_svd_val[...,0] + 1e-8)) * mask.float() 92 | anisotropic_aop = (aop_svd_val[...,1] / (aop_svd_val[...,0] + 1e-8)) * mask.float() 93 | anisotropic_weight = torch.ones_like(anisotropic_aop) 94 | scale_factor = 1/10 95 | anisotropic_weight *= scale_factor 96 | vec_orientation_similarity = F.cosine_similarity(img_svd_vec, aop_svd_vec, dim=-1).abs().sum(dim=-1)/2 97 | eta = F.l1_loss(anisotropic_img, anisotropic_aop, reduction='none')* mask.float() 98 | eta += F.l1_loss(vec_orientation_similarity, torch.ones_like(vec_orientation_similarity), reduction = 'none') * mask.float() * anisotropic_weight 99 | else: 100 | eta = F.l1_loss(normals_image_cov, normals_aop_cov, reduction='none')* mask.float() 101 | if self.w.splat_rew: 102 | dop_w = normals_dop_samples.mean(dim=-1) 103 | eta *= dop_w 104 | losses['loss_gauss'] = self.w.w_splat * eta.sum() / (mask.sum() + 1e-10) 105 | 106 | loss = 0 107 | for k, v in losses.items(): 108 | loss += v 109 | losses['total'] = loss 110 | return losses, None 111 | -------------------------------------------------------------------------------- /models/math_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def _rot_matrix_from_angle(t): 5 | _prefix = t.shape 6 | rot_mat = torch.zeros(*_prefix, 2,2).type_as(t) 7 | rot_mat[...,0,0] = torch.cos(t) 8 | rot_mat[...,1,1] = torch.cos(t) 9 | rot_mat[...,0,1] = -torch.sin(t) 10 | rot_mat[...,1,0] = torch.sin(t) 11 | return rot_mat 12 | 13 | def two_order_real_symmetric_matrix_svd(mat): 14 | """Closed-form SVD of a 2x2 real symmetric matrix 15 | [[a,b], 16 | [b,d]], see https://www.researchgate.net/publication/263580188_Closed_Form_SVD_Solutions_for_2_x_2_Matrices_-_Rev_2 17 | 18 | Args: 19 | mat: (...,2,2) 20 | """ 21 | a, b, d = mat[...,0,0], mat[...,0,1], mat[...,1,1] 22 | sigmas = torch.stack([((a+d).abs() + ((a-d)**2 + 4*(b**2) + 1e-8).sqrt())/2, 23 | ((a+d).abs() - ((a-d)**2 + 4*(b**2) + 1e-8).sqrt()).abs()/2], dim=-1) # (...,2) 24 | det_idx = a*d - b**2 < 0 25 | Sigma = torch.eye(2,2).type_as(sigmas).expand_as(mat).clone() 26 | if det_idx.any(): 27 | Sigma[det_idx, 1, 1] = -1 # flip one sign for negative-determinant inputs; note Sigma[det_idx][:,1,1] = -1 would only write to a copy 28 | D = torch.diag_embed(sigmas) # (...,2,2) 29 | Sigma *= torch.sign(a+d)[...,None,None] 30 | # NOTE: atan may produce NaN when the divisor approaches zero 31 | theta = torch.atan((-torch.sign(b)*(a-d) + torch.sign(a+d)*torch.sign(b)*((a-d)**2 + 4*b**2 + 1e-7).sqrt())/(2*b.abs() + 1e-7)) 32 | rot_mat = _rot_matrix_from_angle(theta) 33 | # print(f'error:{(rot_mat @ D @ Sigma @ rot_mat.transpose(-2,-1) - mat).abs().sum()}') 34 | return (D @ Sigma).diagonal(dim1=-2, dim2=-1), rot_mat -------------------------------------------------------------------------------- /render_view.sh: -------------------------------------------------------------------------------- 1 | PREFIX=$1 2 | CAMERA=$2 3 | cd ~/neurecon 4 | python -m tools.render_view --downscale 4 --config ${PREFIX}/config.yaml \ 5 | --load_pt ${PREFIX}/ckpts/00100000.pt \ 6 | --camera_path ${CAMERA} \ 7 |
--num_views 360 \ 8 | --device_ids 0,1,2,3 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | addict==2.4.0 3 | ansi2html==1.8.0 4 | asttokens==2.2.1 5 | attrs==23.1.0 6 | backcall==0.2.0 7 | cachetools==5.3.1 8 | certifi==2023.7.22 9 | charset-normalizer==3.2.0 10 | click==8.1.7 11 | colorama==0.4.6 12 | comm==0.1.4 13 | ConfigArgParse==1.7 14 | contourpy==1.1.0 15 | cycler==0.11.0 16 | dash==2.12.1 17 | dash-core-components==2.0.0 18 | dash-html-components==2.0.0 19 | dash-table==5.0.0 20 | dearpygui==1.9.1 21 | decorator==5.1.1 22 | executing==1.2.0 23 | fastjsonschema==2.18.0 24 | filelock==3.12.4 25 | Flask==2.2.5 26 | fonttools==4.42.1 27 | fsspec==2023.10.0 28 | fvcore==0.1.5.post20221221 29 | google-auth==2.22.0 30 | google-auth-oauthlib==1.0.0 31 | grpcio==1.57.0 32 | h5py==3.9.0 33 | huggingface-hub==0.18.0 34 | icecream==2.1.3 35 | idna==3.4 36 | imageio==2.19.3 37 | imageio-ffmpeg==0.4.7 38 | importlib-metadata==6.8.0 39 | importlib-resources==6.0.1 40 | iopath==0.1.10 41 | ipython==8.12.2 42 | ipywidgets==8.1.0 43 | itsdangerous==2.1.2 44 | jedi==0.19.0 45 | Jinja2==3.1.2 46 | joblib==1.3.2 47 | jsonschema==4.19.0 48 | jsonschema-specifications==2023.7.1 49 | jupyter_core==5.3.1 50 | jupyterlab-widgets==3.0.8 51 | kiwisolver==1.4.5 52 | lazy_loader==0.3 53 | lpips==0.1.4 54 | Markdown==3.4.4 55 | markdown-it-py==3.0.0 56 | MarkupSafe==2.1.3 57 | matplotlib==3.7.1 58 | matplotlib-inline==0.1.6 59 | mdurl==0.1.2 60 | nbformat==5.7.0 61 | nerfacc==0.3.3 62 | nest-asyncio==1.5.7 63 | networkx==3.1 64 | ninja==1.11.1 65 | numpy==1.23.1 66 | oauthlib==3.2.2 67 | open3d==0.17.0 68 | opencv-python==4.8.0.76 69 | packaging==23.1 70 | pandas==2.0.3 71 | parso==0.8.3 72 | pexpect==4.8.0 73 | pickleshare==0.7.5 74 | Pillow==10.0.0 75 | pkgutil_resolve_name==1.3.10 76 | platformdirs==3.10.0 77 | plotly==5.16.1 78 | plyfile==1.0.1 79 | portalocker==2.7.0 80 | prettytable==3.8.0 81 | prompt-toolkit==3.0.39 82 | protobuf==4.24.1 83 | ptyprocess==0.7.0 84 | pure-eval==0.2.2 85 | pyasn1==0.5.0 86 | pyasn1-modules==0.3.0 87 | pybind11==2.11.1 88 | Pygments==2.16.1 89 | pyhocon==0.3.60 90 | PyMCubes==0.1.4 91 | pyparsing==3.0.9 92 | pyquaternion==0.9.9 93 | python-dateutil==2.8.2 94 | pytz==2023.3 95 | PyWavelets==1.4.1 96 | PyYAML==6.0.1 97 | referencing==0.30.2 98 | requests==2.31.0 99 | requests-oauthlib==1.3.1 100 | retrying==1.3.4 101 | rich==13.5.2 102 | rpds-py==0.9.2 103 | rsa==4.9 104 | safetensors==0.4.0 105 | scikit-image==0.21.0 106 | scikit-learn==1.3.0 107 | scipy==1.10.1 108 | six==1.16.0 109 | stack-data==0.6.2 110 | tabulate==0.9.0 111 | tenacity==8.2.3 112 | tensorboard==2.14.0 113 | tensorboard-data-server==0.7.1 114 | tensorboardX==2.6.2.2 115 | termcolor==2.3.0 116 | threadpoolctl==3.2.0 117 | tifffile==2023.7.10 118 | timm==0.9.8 119 | torch==1.13.1+cu116 120 | torch-tb-profiler==0.4.1 121 | torchaudio==0.13.1+cu116 122 | torchvision==0.14.1+cu116 123 | tqdm==4.66.1 124 | traitlets==5.9.0 125 | transforms3d==0.4.1 126 | trimesh==3.22.5 127 | typing_extensions==4.7.1 128 | tzdata==2023.3 129 | urllib3==1.26.16 130 | wcwidth==0.2.6 131 | Werkzeug==2.2.3 132 | widgetsnbextension==4.0.8 133 | yacs==0.1.8 134 | zipp==3.16.2 -------------------------------------------------------------------------------- /sdf2mesh.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | import 
trimesh 3 | import mcubes 4 | import argparse 5 | import json 6 | import numpy as np 7 | 8 | from utils.print_fn import log 9 | import utils.general as utils 10 | import utils.plots as plt 11 | from utils import rend_util 12 | from utils.io_util import load_yaml 13 | from models.frameworks import get_model 14 | 15 | def scale_anything(dat, inp_scale, tgt_scale): 16 | if inp_scale is None: 17 | inp_scale = [dat.min(), dat.max()] 18 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) 19 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] 20 | return dat 21 | 22 | def extract_fields(bound_min, bound_max, resolution, query_func): 23 | N = 64 24 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) 25 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) 26 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) 27 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) 28 | with torch.no_grad(): 29 | for xi, xs in enumerate(X): 30 | for yi, ys in enumerate(Y): 31 | for zi, zs in enumerate(Z): 32 | xx, yy, zz = torch.meshgrid(xs, ys, zs) 33 | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).cuda() 34 | val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() 35 | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val 36 | return u 37 | 38 | def extract_geometry_(bound_min, bound_max, resolution, threshold, query_func): 39 | log.info('Threshold: {}'.format(threshold)) 40 | sdfs = extract_fields(bound_min, bound_max, resolution, query_func) 41 | log.info('Marching Cubes') 42 | vertices, triangles = mcubes.marching_cubes(sdfs, threshold) 43 | b_max_np = bound_max.detach().cpu().numpy() 44 | b_min_np = bound_min.detach().cpu().numpy() 45 | 46 | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] 47 | return vertices, triangles 48 | 49 | def extract_geometry(model, bound_min, bound_max, resolution, threshold=0.0): 50 | """Coarse-to-Fine Mesh Extraction 51 | 52 | Args: 53 | model: neural sdf model 54 | bound_min: aabb 55 | bound_max: aabb 56 | resolution: grid vertices 57 | threshold: level set. Defaults to 0.0. 
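Example (hypothetical usage; assumes a trained `model` exposing `forward_surface`
and a CUDA device, since extract_fields moves the query points to the GPU):
    >>> bmin = torch.tensor([-1., -1., -1.])
    >>> bmax = torch.tensor([1., 1., 1.])
    >>> verts, tris = extract_geometry(model, bmin, bmax, resolution=256)
    >>> trimesh.Trimesh(verts, tris).export('surface.ply')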
58 | 59 | Returns: 60 | mesh 61 | """ 62 | 63 | return extract_geometry_(bound_min, 64 | bound_max, 65 | resolution=resolution, 66 | threshold=threshold, 67 | query_func=lambda pts: -model.forward_surface(pts)) 68 | 69 | def validate_mesh(model, epoch, args, world_space=False, resolution=64, threshold=0.0): 70 | with open(args.json, 'r') as f: 71 | surface_configs = json.load(f) 72 | bound_min = torch.tensor(surface_configs['bbox_min'], dtype=torch.float32) 73 | bound_max = torch.tensor(surface_configs['bbox_max'], dtype=torch.float32) 74 | log.info('Coarse Bounding Box:') 75 | log.info([bound_min.numpy().tolist(), bound_max.numpy().tolist()]) 76 | vertices, triangles =\ 77 | extract_geometry(model, bound_min, bound_max, resolution=resolution, threshold=threshold) 78 | vertices = torch.from_numpy(vertices) 79 | v_min, v_max = vertices.amin(dim=0), vertices.amax(dim=0) 80 | vmin_ = (v_min - (v_max - v_min) * 0.1).clamp(bound_min , bound_max) 81 | vmax_ = (v_max + (v_max - v_min) * 0.1).clamp(bound_min, bound_max) 82 | log.info('Fine Bounding Box:') 83 | log.info([vmin_.numpy().tolist(), vmax_.numpy().tolist()]) 84 | vertices, triangles =\ 85 | extract_geometry(model, vmin_, vmax_, resolution=resolution, threshold=threshold) 86 | evals_folder_name = surface_configs['eval'] 87 | os.makedirs(evals_folder_name, exist_ok=True) 88 | 89 | # if world_space: 90 | # vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None] 91 | 92 | mesh = trimesh.Trimesh(vertices, triangles) 93 | mesh_save_path =os.path.join(evals_folder_name, 'N_{:0>8d}.ply'.format(epoch)) 94 | mesh.export(mesh_save_path) 95 | log.info(f'Mesh saved in {mesh_save_path}') 96 | log.info('End') 97 | 98 | def model_wrapper(p): 99 | with open(p, 'r') as f: 100 | surface_configs = json.load(f) 101 | log.info(f'Surface config loaded from {p}') 102 | evals_folder_name = surface_configs['eval'] 103 | exps_folder_name = surface_configs['exp'] 104 | utils.mkdir_ifnotexists(os.path.join('./', evals_folder_name)) 105 | expdir = os.path.join('./', exps_folder_name) 106 | evaldir = os.path.join('./', evals_folder_name) 107 | utils.mkdir_ifnotexists(evaldir) 108 | iter = surface_configs['iteration'] 109 | 110 | args = load_yaml(f'{expdir}/config.yaml') 111 | args.device_ids = [0] 112 | model, _, render_kwargs_train, render_kwargs_test, volume_render_fn = get_model(args) 113 | 114 | if torch.cuda.is_available(): 115 | log.info('Cuda Detected') 116 | model.cuda() 117 | checkpoint_path = f'{expdir}/ckpts/{iter}.pt' 118 | 119 | # saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth")) 120 | saved_model_state = torch.load(checkpoint_path) 121 | # print(saved_model_state.keys()) 122 | 123 | model.load_state_dict(saved_model_state["model"]) 124 | epoch = saved_model_state['global_step'] 125 | 126 | return model, epoch 127 | 128 | 129 | if __name__ == '__main__': 130 | 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument('--resolution', default=512, type=int, help='Grid resolution for marching cube') 133 | parser.add_argument('--json', type=str, default='surface.json', help='Surface Configs.') 134 | opt = parser.parse_args() 135 | 136 | model, epoch = model_wrapper(opt.json) 137 | validate_mesh(model, epoch, opt, world_space=False, resolution=opt.resolution, threshold=0.0) -------------------------------------------------------------------------------- /sdf2msh_volsdf.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import argparse, json 3 | import GPUtil 4 | import os 5 | from utils.print_fn import log 6 | 7 | # from pyhocon import ConfigFactory 8 | import torch 9 | import numpy as np 10 | from PIL import Image 11 | from tqdm import tqdm 12 | import pandas as pd 13 | 14 | import utils.general as utils 15 | import utils.plots as plt 16 | from utils import rend_util 17 | from utils.io_util import load_yaml 18 | from models.frameworks import get_model 19 | 20 | def evaluate(**kwargs): 21 | torch.set_default_dtype(torch.float32) 22 | torch.set_num_threads(1) 23 | 24 | # conf = ConfigFactory.parse_file(kwargs['conf']) 25 | exps_folder_name = kwargs['exps_folder_name'] 26 | evals_folder_name = kwargs['evals_folder_name'] 27 | # eval_rendering = kwargs['eval_rendering'] 28 | 29 | 30 | # expname = conf.get_string('train.expname') +'_'+ kwargs['expname'] 31 | 32 | ''' 33 | scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int('dataset.scan_id', default=-1) 34 | if scan_id != -1: 35 | expname = expname + '_{0}'.format(scan_id) 36 | else: 37 | scan_id = conf.get_string('dataset.object', default='') 38 | ''' 39 | 40 | # scan_id = kwargs['scan_id'] 41 | 42 | ''' 43 | if kwargs['timestamp'] == 'latest': 44 | if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], expname)): 45 | timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], expname)) 46 | if (len(timestamps)) == 0: 47 | print('WRONG EXP FOLDER') 48 | exit() 49 | # self.timestamp = sorted(timestamps)[-1] 50 | timestamp = None 51 | for t in sorted(timestamps): 52 | if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], expname, t, 'checkpoints', 53 | 'ModelParameters', str(kwargs['checkpoint']) + ".pth")): 54 | timestamp = t 55 | if timestamp is None: 56 | print('NO GOOD TIMSTAMP') 57 | exit() 58 | else: 59 | print('WRONG EXP FOLDER') 60 | exit() 61 | else: 62 | timestamp = kwargs['timestamp'] 63 | ''' 64 | 65 | utils.mkdir_ifnotexists(os.path.join('./', evals_folder_name)) 66 | expdir = os.path.join('./', exps_folder_name) 67 | evaldir = os.path.join('./', evals_folder_name) 68 | utils.mkdir_ifnotexists(evaldir) 69 | 70 | # dataset_conf = conf.get_config('dataset') 71 | ''' 72 | if kwargs['scan_id'] != -1: 73 | dataset_conf['scan_id'] = kwargs['scan_id'] 74 | eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(**dataset_conf) 75 | ''' 76 | args = load_yaml(f'{expdir}/config.yaml') 77 | args.device_ids = [0] 78 | model, _, render_kwargs_train, render_kwargs_test, volume_render_fn = get_model(args) 79 | 80 | if torch.cuda.is_available(): 81 | model.cuda() 82 | 83 | # settings for camera optimization 84 | # scale_mat = eval_dataset.get_scale_mat() 85 | 86 | ''' 87 | if eval_rendering: 88 | eval_dataloader = torch.utils.data.DataLoader(eval_dataset, 89 | batch_size=1, 90 | shuffle=False, 91 | collate_fn=eval_dataset.collate_fn 92 | ) 93 | total_pixels = eval_dataset.total_pixels 94 | img_res = eval_dataset.img_res 95 | split_n_pixels = conf.get_int('train.split_n_pixels', 10000) 96 | ''' 97 | 98 | # old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints') 99 | 100 | # checkpoint_path = f'{expdir}/ckpts/00100000.pt' 101 | checkpoint_path = f'{expdir}/ckpts/00100000.pt' 102 | # saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth")) 103 | saved_model_state = torch.load(checkpoint_path) 104 | 
print(saved_model_state.keys()) 105 | 106 | model.load_state_dict(saved_model_state["model"]) 107 | epoch = saved_model_state['global_step'] 108 | 109 | #################################################################################################################### 110 | print("evaluating...") 111 | 112 | model.eval() 113 | 114 | with torch.no_grad(): 115 | 116 | ''' 117 | if scan_id < 24: # Blended MVS 118 | mesh = plt.get_surface_high_res_mesh( 119 | sdf=lambda x: model.implicit_network(x)[:, 0], 120 | resolution=kwargs['resolution'], 121 | grid_boundary=conf.get_list('plot.grid_boundary'), 122 | level=conf.get_int('plot.level', default=0), 123 | take_components = type(scan_id) is not str 124 | ) 125 | else: # DTU 126 | bb_dict = np.load('../data/DTU/bbs.npz') 127 | grid_params = bb_dict[str(scan_id)] 128 | 129 | mesh = plt.get_surface_by_grid( 130 | grid_params=grid_params, 131 | sdf=lambda x: model.implicit_network(x)[:, 0], 132 | resolution=kwargs['resolution'], 133 | level=conf.get_int('plot.level', default=0), 134 | higher_res=True 135 | ) 136 | ''' 137 | 138 | ## KK 139 | ## grid_params.shape = (2, 3) 140 | ## 3d bbox: defines the region where marching cubes is run 141 | 142 | # bb_dict = np.load('../data/DTU/bbs.npz') 143 | # grid_params = bb_dict[str(37)] 144 | # grid_params = 10 * np.ones((2, 3)) 145 | # grid_params[0, :] = -1 * grid_params[0, :] 146 | 147 | grid_params = np.array([ 148 | [-0.5, -0.5, 0], 149 | [0.5, 0.5, 1.0] 150 | ]) 151 | 152 | # grid_params = np.array([ 153 | # [-0.3, -0.3, -0.5], 154 | # [0.3, 0.3, 0.1] 155 | # ]) 156 | # grid_params = grid_params + np.array([0.1, 0.1, 0.4]) 157 | print(grid_params) 158 | 159 | 160 | # import IPython; IPython.embed(); exit() 161 | 162 | print("Extracting mesh...") 163 | mesh = plt.get_surface_by_grid( 164 | grid_params=grid_params, 165 | sdf=lambda x: model.forward_surface(x), 166 | resolution=kwargs['resolution'], 167 | level=0, 168 | # higher_res=True 169 | higher_res=False 170 | ) 171 | 172 | print("Almost done...") 173 | 174 | # Transform to world coordinates 175 | # mesh.apply_transform(scale_mat) 176 | 177 | # Taking the biggest connected component 178 | components = mesh.split(only_watertight=False) 179 | areas = np.array([c.area for c in components], dtype=np.float32) 180 | mesh_clean = components[areas.argmax()] 181 | 182 | mesh_folder = '{0}'.format(evaldir) 183 | utils.mkdir_ifnotexists(mesh_folder) 184 | 185 | # mesh_clean.export('{0}/scan{1}.ply'.format(mesh_folder, scan_id), 'ply') 186 | mesh_clean.export('{0}/scan{1}_{2}.ply'.format(mesh_folder, 'test', epoch), 'ply') 187 | 188 | ''' 189 | if eval_rendering: 190 | images_dir = '{0}/rendering_{1}'.format(evaldir, epoch) 191 | utils.mkdir_ifnotexists(images_dir) 192 | 193 | psnrs = [] 194 | for data_index, (indices, model_input, ground_truth) in enumerate(eval_dataloader): 195 | model_input["intrinsics"] = model_input["intrinsics"].cuda() 196 | model_input["uv"] = model_input["uv"].cuda() 197 | model_input['pose'] = model_input['pose'].cuda() 198 | 199 | split = utils.split_input(model_input, total_pixels, n_pixels=split_n_pixels) 200 | res = [] 201 | for s in tqdm(split): 202 | torch.cuda.empty_cache() 203 | out = model(s) 204 | res.append({ 205 | 'rgb_values': out['rgb_values'].detach(), 206 | }) 207 | 208 | batch_size = ground_truth['rgb'].shape[0] 209 | model_outputs = utils.merge_output(res, total_pixels, batch_size) 210 | rgb_eval = model_outputs['rgb_values'] 211 | rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3) 212 | 213 | rgb_eval = plt.lin2img(rgb_eval,
img_res).detach().cpu().numpy()[0] 214 | rgb_eval = rgb_eval.transpose(1, 2, 0) 215 | img = Image.fromarray((rgb_eval * 255).astype(np.uint8)) 216 | img.save('{0}/eval_{1}.png'.format(images_dir,'%03d' % indices[0])) 217 | 218 | psnr = rend_util.get_psnr(model_outputs['rgb_values'], 219 | ground_truth['rgb'].cuda().reshape(-1, 3)).item() 220 | psnrs.append(psnr) 221 | 222 | 223 | psnrs = np.array(psnrs).astype(np.float64) 224 | print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".format("%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id)) 225 | psnrs = np.concatenate([psnrs, psnrs.mean()[None], psnrs.std()[None]]) 226 | pd.DataFrame(psnrs).to_csv('{0}/psnr_{1}.csv'.format(evaldir, epoch)) 227 | ''' 228 | 229 | 230 | if __name__ == '__main__': 231 | 232 | parser = argparse.ArgumentParser() 233 | 234 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf') 235 | parser.add_argument('--expname', type=str, default='', help='The experiment name to be evaluated.') 236 | parser.add_argument('--exps_folder', type=str, default='exps', help='The experiments folder name.') 237 | parser.add_argument('--evals_folder', type=str, default='evals', help='The evaluation folder name.') 238 | parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]') 239 | parser.add_argument('--timestamp', default='latest', type=str, help='The experiment timestamp to test.') 240 | parser.add_argument('--checkpoint', default='latest', type=str, help='The trained model checkpoint to test.') 241 | parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.') 242 | parser.add_argument('--resolution', default=1024, type=int, help='Grid resolution for marching cube') 243 | parser.add_argument('--eval_rendering', default=False, action="store_true", help='If set, evaluate rendering quality.') 244 | parser.add_argument('--json', type=str, default='surface.json', help='Surface Configs.') 245 | 246 | opt = parser.parse_args() 247 | 248 | if opt.gpu == "auto": 249 | deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, excludeID=[], excludeUUID=[]) 250 | gpu = deviceIDs[0] 251 | else: 252 | gpu = opt.gpu 253 | 254 | if gpu != 'ignore': 255 | os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(gpu) 256 | 257 | evaluate(conf=opt.conf, 258 | expname=opt.expname, 259 | timestamp=opt.timestamp, 260 | checkpoint=opt.checkpoint, 261 | scan_id=opt.scan_id, 262 | resolution=opt.resolution, 263 | eval_rendering=opt.eval_rendering, 264 | exps_folder_name=opt.exps_folder, 265 | evals_folder_name=opt.evals_folder, 266 | ) 267 | -------------------------------------------------------------------------------- /surface.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp": "logs/pnr/pmivr_camera_rgb_1.0", 3 | "eval": "logs/pnr/pmivr_camera_rgb_1.0", 4 | "bbox_min": [ 5 | -0.2, 6 | 0, 7 | -0.04 8 | ], 9 | "bbox_max": [ 10 | 0.3, 11 | 1.0, 12 | 0.5 13 | ], 14 | "iteration": "latest" 15 | } -------------------------------------------------------------------------------- /tools/360cameraPath/camera_intrinsics.json: -------------------------------------------------------------------------------- 1 | { 2 | "intrinsics": [ 3 | [3707.8212890625, 0, 1223.5], 4 | [0, 3707.8212890625, 1023.5], 5 | [0, 0, 1] 6 | ] 7 | } -------------------------------------------------------------------------------- /tools/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/tools/__init__.py -------------------------------------------------------------------------------- /tools/azi2aop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import imageio,json, cv2 4 | import dataio.polanalyser as pa 5 | 6 | def normal_to_aop(pose, 7 | normal_map): 8 | '''From normals to predicted aop 9 | 10 | Return: 11 | [B, N_rays] 12 | ''' 13 | 14 | w2c = pose[:, :3,:3].transpose(1, 2) # R^T, B x 3 x 3 15 | # check_np(w2c) 16 | 17 | (N_batch, N_rays, _) = normal_map.shape 18 | 19 | # normal_map = normal_map.reshape([N_batch, 3, N_samples]) 20 | normal_map = normal_map.transpose(1, 2) # [B, 3, N_rays] 21 | normal_map_cam = torch.bmm(w2c, normal_map) # B x 3 x 3 @ B x 3 x N_rays = B x 3 x N_rays 22 | normal_map_cam = normal_map_cam.transpose(1, 2) # B x N_rays x 3 23 | phi = torch.atan2(normal_map_cam[...,1], normal_map_cam[...,0]) # N_batch x N_rays (rad) 24 | 25 | # MOD: PMIVR Loss Deprecated 26 | # eta = torch.stack([torch.abs(phi-AoP_gt-np.pi/2), torch.abs(phi-AoP_gt), torch.abs(phi-AoP_gt+np.pi/2), torch.abs(phi-AoP_gt+np.pi)], dim=1) 27 | # eta, _ = torch.min(eta, dim=1) 28 | 29 | # phi_to_aop = np.pi/2 - phi 30 | phi_to_aop = phi 31 | # mod to [0, pi] 32 | phi_to_aop = torch.remainder(phi_to_aop, np.pi) 33 | 34 | return phi_to_aop 35 | 36 | def get_pose(cameraJson, idx): 37 | rotation = torch.Tensor(cameraJson[idx]['rotation']).transpose(1, 0).float() # R^T 38 | translation = torch.Tensor(cameraJson[idx]['camera_pos']).float() # C = -R_transpose*t 39 | c2w_ = torch.cat([rotation, translation.unsqueeze(1)], dim=1) # 3 x 4 40 | c2w = torch.cat([c2w_, torch.Tensor([[0.,0.,0.,1.]])], dim=0) 41 | return c2w 42 | 43 | if __name__ == '__main__': 44 | normal = imageio.imread('normal.png')/255 45 | (H, W, _) = normal.shape 46 | normal_ori = (normal - 0.5)*2 47 | normal_ori = torch.from_numpy(normal_ori[None,...]).flatten(1,2).float() 48 | with open('/camera_extrinsics.json', 'r') as f: 49 | camera_extrinsics = json.load(f) 50 | c2w = get_pose(cameraJson=camera_extrinsics,idx='23.png')[None,...] 
# Batched 51 | aop = normal_to_aop(c2w, normal_ori).numpy().reshape((H, W)).squeeze() 52 | print(aop.max(),aop.min()) 53 | aop_img = pa.applyColorToAoLP(aop) 54 | cv2.imwrite('pred_aop_ori.png', aop_img) -------------------------------------------------------------------------------- /tools/eval_3dprint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import sklearn.neighbors as skln 4 | from tqdm import tqdm 5 | from scipy.io import loadmat 6 | import multiprocessing as mp 7 | import trimesh, os 8 | 9 | def sample_single_tri(input_): 10 | n1, n2, v1, v2, tri_vert = input_ 11 | c = np.mgrid[:n1 + 1, :n2 + 1] 12 | c += 0.5 13 | c[0] /= max(n1, 1e-7) 14 | c[1] /= max(n2, 1e-7) 15 | c = np.transpose(c, (1, 2, 0)) 16 | k = c[c.sum(axis=-1) < 1] # m2 17 | q = v1 * k[:, :1] + v2 * k[:, 1:] + tri_vert 18 | return q 19 | 20 | def write_vis_pcd(file, points, colors): 21 | pcd = o3d.geometry.PointCloud() 22 | pcd.points = o3d.utility.Vector3dVector(points) 23 | pcd.colors = o3d.utility.Vector3dVector(colors) 24 | o3d.io.write_point_cloud(file, pcd) 25 | 26 | def evaluation_3d_print(data_path, dataset_dir, vis_out_dir, downsample_density=0.001, patch_size=60, max_dist_d=100, 27 | max_dist_t=10, visualize_threshold=10, points_for_plane=None, nonvalid_bbox=None,z_min=None): 28 | mp.freeze_support() 29 | data = o3d.io.read_triangle_mesh(data_path) 30 | f=open(os.path.join(os.path.dirname(dataset_dir), f'CD.txt'), 'a') 31 | thresh = downsample_density 32 | 33 | pbar = tqdm(total=9) 34 | pbar.set_description('read data mesh') 35 | data_mesh = data 36 | 37 | vertices = np.asarray(data_mesh.vertices) 38 | triangles = np.asarray(data_mesh.triangles) 39 | tri_vert = vertices[triangles] 40 | 41 | pbar.update(1) 42 | pbar.set_description('sample pcd from mesh') 43 | 44 | v1 = tri_vert[:,1] - tri_vert[:,0] 45 | v2 = tri_vert[:,2] - tri_vert[:,0] 46 | l1 = np.linalg.norm(v1, axis=-1, keepdims=True) 47 | l2 = np.linalg.norm(v2, axis=-1, keepdims=True) 48 | area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True) 49 | non_zero_area = (area2 > 0)[:,0] 50 | l1, l2, area2, v1, v2, tri_vert = [ 51 | arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert]] 52 | thr = thresh * np.sqrt(l1 * l2 / area2) 53 | n1 = np.floor(l1 / thr) 54 | n2 = np.floor(l2 / thr) 55 | 56 | with mp.Pool() as mp_pool: 57 | new_pts = mp_pool.map(sample_single_tri, ((n1[i,0], n2[i,0], v1[i:i+1], v2[i:i+1], tri_vert[i:i+1,0]) for i in range(len(n1))), chunksize=1024) 58 | 59 | new_pts = np.concatenate(new_pts, axis=0) 60 | data_pcd = np.concatenate([vertices, new_pts], axis=0) 61 | 62 | # # # save dense point cloud 63 | # PCD = o3d.geometry.PointCloud() 64 | # PCD.points = o3d.utility.Vector3dVector(data_pcd) 65 | # o3d.io.write_point_cloud('/newdata/wenhangge/data/refnerf/toaster/dense_pcd.ply' ,PCD) 66 | 67 | 68 | pbar.update(1) 69 | pbar.set_description('random shuffle pcd index') 70 | shuffle_rng = np.random.default_rng() 71 | shuffle_rng.shuffle(data_pcd, axis=0) 72 | 73 | pbar.update(1) 74 | pbar.set_description('downsample pcd') 75 | nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=-1) 76 | nn_engine.fit(data_pcd) 77 | rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False) 78 | mask = np.ones(data_pcd.shape[0], dtype=np.bool_) 79 | for curr, idxs in enumerate(rnn_idxs): 80 | if mask[curr]: 81 | mask[idxs] = 0 82 | mask[curr] = 1 83 | data_down = data_pcd[mask] 84 | 
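# The loop above is a greedy radius-based downsampling (a simple Poisson-disk-style filter):
# a point survives only if no previously kept point lies within `thresh` of it.
# A minimal standalone sketch of the same idea on toy data (illustrative only):
#   pts = np.array([[0., 0.], [0.0005, 0.], [1., 1.]])
#   nn = skln.NearestNeighbors(radius=0.001).fit(pts)
#   idxs = nn.radius_neighbors(pts, return_distance=False)
#   keep = np.ones(len(pts), dtype=np.bool_)
#   for cur, neigh in enumerate(idxs):
#       if keep[cur]:
#           keep[neigh] = 0
#           keep[cur] = 1
#   # keep -> [True, False, True]: the near-duplicate second point is dropped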
85 | pbar.update(1) 86 | pbar.set_description('read STL pcd') 87 | stl_pcd = o3d.io.read_point_cloud(dataset_dir) 88 | stl = np.asarray(stl_pcd.points) 89 | # BB = np.array([stl.min(0), stl.max(0)]) 90 | BB = np.array([vertices.min(0), vertices.max(0)]) 91 | print(BB) 92 | # compute lowest surface 93 | if points_for_plane is not None: 94 | p1 = np.array(points_for_plane[0]) 95 | p2 = np.array(points_for_plane[1]) 96 | p3 = np.array(points_for_plane[2]) 97 | else: 98 | z_min = BB[0,2] if z_min is None else z_min 99 | points_for_plane = [np.array([0,0,z_min]),np.array([0,1,z_min]),np.array([1,0,z_min])] 100 | p1 = np.array(points_for_plane[0]) 101 | p2 = np.array(points_for_plane[1]) 102 | p3 = np.array(points_for_plane[2]) 103 | v1 = p1 - p2 104 | v2 = p3 - p2 105 | normal = np.cross(v1, v2) 106 | # make sure the normal toward positive z 107 | if normal[-1] < 0: 108 | normal = np.cross(v2, v1) 109 | D = np.dot(normal, p1) 110 | 111 | pbar.update(1) 112 | pbar.set_description('masking data pcd') 113 | 114 | BB = BB.astype(np.float32) 115 | 116 | patch = patch_size 117 | inbound = ((data_down >= BB[:1]-patch) & (data_down < BB[1:]+patch*2)).sum(axis=-1) ==3 118 | data_in = data_down[inbound] 119 | 120 | above = (data_in @ normal - D) > 0 121 | data_in_above = data_in[above] 122 | 123 | above_stl = (stl @ normal - D) > 0 124 | stl_above = stl[above_stl] 125 | 126 | if nonvalid_bbox is not None: 127 | aa = nonvalid_bbox[0] 128 | bb = nonvalid_bbox[1] 129 | 130 | mask_bbox = ((data_in_above >= bb) & (data_in_above <= aa)).sum(axis=-1) ==3 131 | mask_val = ~mask_bbox 132 | else: 133 | mask_val = np.ones_like(data_in_above) 134 | mask_val = mask_val.astype(bool)[:, 0] 135 | data_in_above = data_in_above[mask_val] 136 | 137 | pbar.update(1) 138 | pbar.set_description('compute data2stl') 139 | nn_engine.fit(stl) 140 | dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_above, n_neighbors=1, return_distance=True) 141 | mean_d2s = dist_d2s[dist_d2s < max_dist_d].mean() 142 | 143 | pbar.update(1) 144 | pbar.set_description('compute stl2data') 145 | nn_engine.fit(data_in) 146 | dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True) 147 | mean_s2d = dist_s2d[dist_s2d < max_dist_t].mean() 148 | 149 | pbar.update(1) 150 | pbar.set_description('visualize error') 151 | vis_dist = visualize_threshold 152 | R = np.array([[1,0,0]], dtype=np.float64) 153 | G = np.array([[0,1,0]], dtype=np.float64) 154 | B = np.array([[0,0,1]], dtype=np.float64) 155 | W = np.array([[1,1,1]], dtype=np.float64) 156 | data_color = np.tile(B, (data_down.shape[0], 1)) 157 | data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist 158 | 159 | data_color[ np.where(inbound)[0][above][mask_val] ] = R * data_alpha + W * (1-data_alpha) 160 | data_color[ np.where(inbound)[0][above][mask_val] [dist_d2s[:,0] >= max_dist_d] ] = G 161 | os.makedirs(vis_out_dir, exist_ok=True) 162 | write_vis_pcd(f'{vis_out_dir}/vis_d2s.ply', data_down, data_color) 163 | 164 | stl_color = np.tile(B, (stl.shape[0], 1)) 165 | stl_alpha = dist_s2d.clip(max=vis_dist) / vis_dist 166 | stl_color[ np.where(above_stl)[0] ] = R * stl_alpha + W * (1-stl_alpha) 167 | stl_color[ np.where(above_stl)[0][dist_s2d[:,0] >= max_dist_t] ] = G 168 | write_vis_pcd(f'{vis_out_dir}/vis_s2d.ply', stl, stl_color) 169 | 170 | pbar.update(1) 171 | pbar.set_description('done') 172 | pbar.close() 173 | over_all = (mean_d2s + mean_s2d) / 2 174 | 175 | print(mean_d2s, mean_s2d, over_all) 176 | f.write(str(data_path) + '_CD: ') 177 | f.write(str(mean_d2s) + ' ') 178 | 
f.write(str(mean_s2d) + ' ') 179 | f.write(str(over_all) + ' ') 180 | f.write('\n') 181 | f.flush() 182 | f.close() 183 | return mean_d2s, mean_s2d, over_all 184 | 185 | if __name__ == '__main__': 186 | data_path = f'eval/Duck/duck@rewmean.ply' 187 | evaluation_3d_print(data_path, f'eval/Duck/Duck.ply', f'eval/dragon/nero',downsample_density=0.1) -------------------------------------------------------------------------------- /tools/extract_surface.py: -------------------------------------------------------------------------------- 1 | from models.base import ImplicitSurface 2 | from utils.mesh_util import extract_mesh 3 | 4 | import torch 5 | 6 | def main_function(args): 7 | N = args.N 8 | s = args.volume_size 9 | implicit_surface = ImplicitSurface(radius_init=args.init_r).cuda() 10 | if args.load_pt is not None: 11 | # --------- if load statedict 12 | # state_dict = torch.load("/home/PJLAB/guojianfei/latest.pt") 13 | # state_dict = torch.load("./dev_test/37/latest.pt") 14 | state_dict = torch.load(args.load_pt) 15 | imp_surface_state_dict = {k.replace('implicit_surface.',''):v for k, v in state_dict['model'].items() if 'implicit_surface.' in k} 16 | imp_surface_state_dict['obj_bounding_size'] = torch.tensor([1.0]).cuda() 17 | implicit_surface.load_state_dict(imp_surface_state_dict) 18 | if args.out is None: 19 | from datetime import datetime 20 | dt = datetime.now() 21 | args.out = 'surface_' + dt.strftime("%Y%m%d%H%M%S") + '.ply' 22 | extract_mesh(implicit_surface, s, N=N, filepath=args.out, show_progress=True, chunk=args.chunk) 23 | 24 | 25 | if __name__ == "__main__": 26 | import argparse 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--out", type=str, default=None, help='output ply file name') 29 | parser.add_argument('--N', type=int, default=512, help='resolution of the marching cube algo') 30 | parser.add_argument('--volume_size', type=float, default=2., help='voxel size to run marching cube') 31 | parser.add_argument("--load_pt", type=str, default=None, help='the trained model checkpoint .pt file') 32 | parser.add_argument("--chunk", type=int, default=16*1024, help='net chunk when querying the network. change for smaller GPU memory.') 33 | parser.add_argument("--init_r", type=float, default=1.0, help='Optional. 
The init radius of the implicit surface.') 34 | args = parser.parse_args() 35 | 36 | main_function(args) -------------------------------------------------------------------------------- /tools/vis_camera.py: -------------------------------------------------------------------------------- 1 | ''' 2 | camera extrinsics visualization tools 3 | modified from https://github.com/opencv/opencv/blob/master/samples/python/camera_calibration_show_extrinsics.py 4 | ''' 5 | 6 | from utils.print_fn import log 7 | # Python 2/3 compatibility 8 | 9 | import numpy as np 10 | import cv2 as cv 11 | 12 | from numpy import linspace 13 | import matplotlib 14 | 15 | # matplotlib.use('TkAgg') 16 | 17 | def inverse_homogeneoux_matrix(M): 18 | R = M[0:3, 0:3] 19 | T = M[0:3, 3] 20 | M_inv = np.identity(4) 21 | M_inv[0:3, 0:3] = R.T 22 | M_inv[0:3, 3] = -(R.T).dot(T) 23 | 24 | return M_inv 25 | 26 | 27 | def transform_to_matplotlib_frame(cMo, X, inverse=False): 28 | M = np.identity(4) 29 | M[1, 1] = 0 30 | M[1, 2] = 1 31 | M[2, 1] = -1 32 | M[2, 2] = 0 33 | 34 | if inverse: 35 | return M.dot(inverse_homogeneoux_matrix(cMo).dot(X)) 36 | else: 37 | return M.dot(cMo.dot(X)) 38 | 39 | 40 | def create_camera_model(camera_matrix, width, height, scale_focal, draw_frame_axis=False): 41 | fx = camera_matrix[0, 0] 42 | fy = camera_matrix[1, 1] 43 | focal = 2 / (fx + fy) 44 | f_scale = scale_focal * focal 45 | 46 | # draw image plane 47 | X_img_plane = np.ones((4, 5)) 48 | X_img_plane[0:3, 0] = [-width, height, f_scale] 49 | X_img_plane[0:3, 1] = [width, height, f_scale] 50 | X_img_plane[0:3, 2] = [width, -height, f_scale] 51 | X_img_plane[0:3, 3] = [-width, -height, f_scale] 52 | X_img_plane[0:3, 4] = [-width, height, f_scale] 53 | 54 | # draw triangle above the image plane 55 | X_triangle = np.ones((4, 3)) 56 | X_triangle[0:3, 0] = [-width, -height, f_scale] 57 | X_triangle[0:3, 1] = [0, -2*height, f_scale] 58 | X_triangle[0:3, 2] = [width, -height, f_scale] 59 | 60 | # draw camera 61 | X_center1 = np.ones((4, 2)) 62 | X_center1[0:3, 0] = [0, 0, 0] 63 | X_center1[0:3, 1] = [-width, height, f_scale] 64 | 65 | X_center2 = np.ones((4, 2)) 66 | X_center2[0:3, 0] = [0, 0, 0] 67 | X_center2[0:3, 1] = [width, height, f_scale] 68 | 69 | X_center3 = np.ones((4, 2)) 70 | X_center3[0:3, 0] = [0, 0, 0] 71 | X_center3[0:3, 1] = [width, -height, f_scale] 72 | 73 | X_center4 = np.ones((4, 2)) 74 | X_center4[0:3, 0] = [0, 0, 0] 75 | X_center4[0:3, 1] = [-width, -height, f_scale] 76 | 77 | # draw camera frame axis 78 | X_frame1 = np.ones((4, 2)) 79 | X_frame1[0:3, 0] = [0, 0, 0] 80 | X_frame1[0:3, 1] = [f_scale/2, 0, 0] 81 | 82 | X_frame2 = np.ones((4, 2)) 83 | X_frame2[0:3, 0] = [0, 0, 0] 84 | X_frame2[0:3, 1] = [0, f_scale/2, 0] 85 | 86 | X_frame3 = np.ones((4, 2)) 87 | X_frame3[0:3, 0] = [0, 0, 0] 88 | X_frame3[0:3, 1] = [0, 0, f_scale/2] 89 | 90 | if draw_frame_axis: 91 | return [X_img_plane, X_triangle, X_center1, X_center2, X_center3, X_center4, X_frame1, X_frame2, X_frame3] 92 | else: 93 | return [X_img_plane, X_triangle, X_center1, X_center2, X_center3, X_center4] 94 | 95 | 96 | def create_board_model(extrinsics, board_width, board_height, square_size, draw_frame_axis=False): 97 | width = board_width*square_size 98 | height = board_height*square_size 99 | 100 | # draw calibration board 101 | X_board = np.ones((4, 5)) 102 | #X_board_cam = np.ones((extrinsics.shape[0],4,5)) 103 | X_board[0:3, 0] = [0, 0, 0] 104 | X_board[0:3, 1] = [width, 0, 0] 105 | X_board[0:3, 2] = [width, height, 0] 106 | X_board[0:3, 3] = [0, height, 0] 107 | 
X_board[0:3, 4] = [0, 0, 0] 108 | 109 | # draw board frame axis 110 | X_frame1 = np.ones((4, 2)) 111 | X_frame1[0:3, 0] = [0, 0, 0] 112 | X_frame1[0:3, 1] = [height/2, 0, 0] 113 | 114 | X_frame2 = np.ones((4, 2)) 115 | X_frame2[0:3, 0] = [0, 0, 0] 116 | X_frame2[0:3, 1] = [0, height/2, 0] 117 | 118 | X_frame3 = np.ones((4, 2)) 119 | X_frame3[0:3, 0] = [0, 0, 0] 120 | X_frame3[0:3, 1] = [0, 0, height/2] 121 | 122 | if draw_frame_axis: 123 | return [X_board, X_frame1, X_frame2, X_frame3] 124 | else: 125 | return [X_board] 126 | 127 | 128 | def draw_camera(ax, camera_matrix, cam_width, cam_height, scale_focal, 129 | extrinsics, 130 | patternCentric=True, 131 | annotation=True): 132 | from matplotlib import cm 133 | 134 | min_values = np.zeros((3, 1)) 135 | min_values = np.inf 136 | max_values = np.zeros((3, 1)) 137 | max_values = -np.inf 138 | 139 | X_moving = create_camera_model( 140 | camera_matrix, cam_width, cam_height, scale_focal) 141 | 142 | cm_subsection = linspace(0.0, 1.0, extrinsics.shape[0]) 143 | colors = [cm.jet(x) for x in cm_subsection] 144 | 145 | for idx in range(extrinsics.shape[0]): 146 | # R, _ = cv.Rodrigues(extrinsics[idx,0:3]) 147 | # cMo = np.eye(4,4) 148 | # cMo[0:3,0:3] = R 149 | # cMo[0:3,3] = extrinsics[idx,3:6] 150 | cMo = extrinsics[idx] 151 | for i in range(len(X_moving)): 152 | X = np.zeros(X_moving[i].shape) 153 | for j in range(X_moving[i].shape[1]): 154 | X[0:4, j] = transform_to_matplotlib_frame( 155 | cMo, X_moving[i][0:4, j], patternCentric) 156 | ax.plot3D(X[0, :], X[1, :], X[2, :], color=colors[idx]) 157 | min_values = np.minimum(min_values, X[0:3, :].min(1)) 158 | max_values = np.maximum(max_values, X[0:3, :].max(1)) 159 | # modified: add an annotation of number 160 | if annotation: 161 | X = transform_to_matplotlib_frame( 162 | cMo, X_moving[0][0:4, 0], patternCentric) 163 | ax.text(X[0], X[1], X[2], "{}".format(idx), color=colors[idx]) 164 | 165 | return min_values, max_values 166 | 167 | 168 | def visualize(camera_matrix, extrinsics): 169 | 170 | ######################## plot params ######################## 171 | cam_width = 0.064/2 # Width/2 of the displayed camera. 172 | cam_height = 0.048/2 # Height/2 of the displayed camera. 173 | scale_focal = 40 # Value to scale the focal length. 
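# How the drawn size relates to the intrinsics (see create_camera_model above):
# focal = 2 / (fx + fy), and the image plane is placed f_scale = scale_focal * focal
# in front of each camera center. Illustrative numbers, plugging in the intrinsics
# from tools/360cameraPath/camera_intrinsics.json (fx = fy ~= 3707.82):
#   focal ~= 2 / 7415.64 ~= 2.7e-4
#   f_scale ~= 40 * 2.7e-4 ~= 0.011 world units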
174 | 175 | ######################## original code ######################## 176 | import matplotlib.pyplot as plt 177 | from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-variable 178 | 179 | fig = plt.figure(dpi=400) 180 | ax = fig.add_subplot(projection='3d') 181 | # ax.set_aspect("equal") 182 | ax.set_aspect("auto") 183 | 184 | min_values, max_values = draw_camera(ax, camera_matrix, cam_width, cam_height, 185 | scale_focal, extrinsics, True) 186 | 187 | X_min = min_values[0] 188 | X_max = max_values[0] 189 | Y_min = min_values[1] 190 | Y_max = max_values[1] 191 | Z_min = min_values[2] 192 | Z_max = max_values[2] 193 | max_range = np.array([X_max-X_min, Y_max-Y_min, Z_max-Z_min]).max() / 2.0 194 | 195 | mid_x = (X_max+X_min) * 0.5 196 | mid_y = (Y_max+Y_min) * 0.5 197 | mid_z = (Z_max+Z_min) * 0.5 198 | ax.set_xlim(mid_x - max_range, mid_x + max_range) 199 | ax.set_ylim(mid_y - max_range, mid_y + max_range) 200 | ax.set_zlim(mid_z - max_range, mid_z + max_range) 201 | 202 | ax.set_xlabel('x') 203 | ax.set_ylabel('y') 204 | ax.set_zlabel('z') 205 | ax.set_title('Extrinsic Parameters Visualization') 206 | 207 | plt.savefig('./cameras.png') 208 | plt.show() 209 | log.info('Done') 210 | 211 | 212 | if __name__ == '__main__': 213 | import argparse 214 | parser = argparse.ArgumentParser() 215 | parser.add_argument("--scan_id", type=int, default=40) 216 | args = parser.parse_args() 217 | 218 | log.info(__doc__) 219 | # NOTE: jianfei: 20210722 newly checked. The coordinate is correct. 220 | # note that the ticks on (-y) mean the opposite of the y coordinates. 221 | 222 | ######################## modified: example code ######################## 223 | from dataio.DTU import SceneDataset 224 | import torch 225 | train_dataset = SceneDataset( 226 | train_cameras=False, 227 | data_dir='./data/DTU/scan{}'.format(args.scan_id)) 228 | c2w = torch.stack(train_dataset.c2w_all).data.cpu().numpy() 229 | extrinsics = np.linalg.inv(c2w) # camera extrinsics are w2c matrix 230 | camera_matrix = next(iter(train_dataset))[1]['intrinsics'].data.cpu().numpy() 231 | 232 | 233 | # import pickle 234 | # data = pickle.load(open('./dev_test/london/london_siren_si20_cam.pt', 'rb')) 235 | # c2ws = data['c2w'] 236 | # extrinsics = np.linalg.inv(c2ws) 237 | # camera_matrix = data['intr'] 238 | visualize(camera_matrix, extrinsics) 239 | cv.destroyAllWindows() 240 | -------------------------------------------------------------------------------- /tools/vis_ray.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import sys 4 | import torch 5 | sys.path.append('/dataset/yokoli/neurecon') 6 | from utils.rend_util import get_rays 7 | # from dataio.normalData import SceneDataset 8 | from dataio.PolData import SceneDataset 9 | import torch.nn.functional as F 10 | def lift(x, y, z, intrinsics): 11 | device = x.device 12 | # parse intrinsics 13 | intrinsics = intrinsics.to(device) 14 | fx = intrinsics[..., 0, 0] 15 | fy = intrinsics[..., 1, 1] 16 | cx = intrinsics[..., 0, 2] 17 | cy = intrinsics[..., 1, 2] 18 | sk = intrinsics[..., 0, 1] 19 | 20 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z 21 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z 22 | 23 | # homogeneous 24 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).to(device)), dim=-1) 25 | def lift_opengl(x, y, z, intrinsics): 26 | # NOTE: OpenGL
convention 27 | device = x.device 28 | # parse intrinsics 29 | intrinsics = intrinsics.to(device) 30 | fx = intrinsics[..., 0, 0] 31 | fy = intrinsics[..., 1, 1] 32 | cx = intrinsics[..., 0, 2] 33 | cy = intrinsics[..., 1, 2] 34 | 35 | x_lift = (x - cx.unsqueeze(-1)) / fx.unsqueeze(-1) * z 36 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z 37 | 38 | # homogeneous and CONVERT TO OPENGL 39 | return torch.stack((x_lift, -y_lift, -z, torch.ones_like(z).to(device)), dim=-1) 40 | 41 | def get_center_ray(c2w, intrinsics, H, W, N_rays=1): 42 | device = c2w.device 43 | cam_loc = c2w[..., :3, 3] 44 | p = c2w 45 | 46 | prefix = p.shape[:-2] 47 | device = c2w.device 48 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) 49 | i = i.t().to(device).reshape([*[1]*len(prefix), H*W]).expand([*prefix, H*W]) 50 | j = j.t().to(device).reshape([*[1]*len(prefix), H*W]).expand([*prefix, H*W]) 51 | 52 | if N_rays > 0: 53 | N_rays = min(N_rays, H*W) 54 | # ---------- option 1: full image uniformly randomize 55 | # select_inds = torch.from_numpy( 56 | # np.random.choice(H*W, size=[*prefix, N_rays], replace=False)).to(device) 57 | # select_inds = torch.randint(0, H*W, size=[N_rays]).expand([*prefix, N_rays]).to(device) 58 | # ---------- option 2: H/W seperately randomize 59 | select_hs = torch.Tensor([H//2]).long() 60 | select_ws = torch.Tensor([W//2]).long() 61 | select_inds = select_hs * W + select_ws 62 | select_inds = select_inds.expand([*prefix, N_rays]) 63 | 64 | i = torch.gather(i, -1, select_inds) 65 | j = torch.gather(j, -1, select_inds) 66 | 67 | # pixel_points_cam = lift(i, j, torch.ones_like(i).to(device), intrinsics=intrinsics) 68 | 69 | pixel_points_cam = lift_opengl(i, j, torch.ones_like(i).to(device), intrinsics=intrinsics) 70 | 71 | # permute for batch matrix product 72 | pixel_points_cam = pixel_points_cam.transpose(-1,-2) 73 | 74 | # NOTE: left-multiply. 75 | # after the above permute(), shapes of coordinates changed from [B,N,4] to [B,4,N], which ensures correct left-multiplication 76 | # p is camera 2 world matrix. 
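# Shape walk-through of the batched branch below (illustrative only):
#   p = torch.eye(4).expand(2, 4, 4)                    # [B, 4, 4] c2w matrices
#   pts = torch.randn(2, 4, 5)                          # [B, 4, N] homogeneous points
#   torch.bmm(p, pts).transpose(-1, -2)[..., :3].shape  # -> torch.Size([2, 5, 3])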
77 | if len(prefix) > 0: 78 | world_coords = torch.bmm(p, pixel_points_cam).transpose(-1, -2)[..., :3] 79 | else: 80 | world_coords = torch.mm(p, pixel_points_cam).transpose(-1, -2)[..., :3] 81 | rays_d = world_coords - cam_loc[..., None, :] 82 | rays_d = F.normalize(rays_d, dim=2) 83 | 84 | rays_o = cam_loc[..., None, :].expand_as(rays_d) 85 | 86 | return rays_o, rays_d, select_inds 87 | 88 | def plot_rays(rays_o: np.ndarray, rays_d: np.ndarray, ax): 89 | # TODO: automatic reducing number of rays 90 | XYZUVW = np.concatenate([rays_o, rays_d], axis=-1) 91 | X, Y, Z, U, V, W = np.transpose(XYZUVW) 92 | # X2 = X+U 93 | # Y2 = Y+V 94 | # Z2 = Z+W 95 | # x_max = max(np.max(X), np.max(X2)) 96 | # x_min = min(np.min(X), np.min(X2)) 97 | # y_max = max(np.max(Y), np.max(Y2)) 98 | # y_min = min(np.min(Y), np.min(Y2)) 99 | # z_max = max(np.max(Z), np.max(Z2)) 100 | # z_min = min(np.min(Z), np.min(Z2)) 101 | # fig = plt.figure() 102 | # ax = fig.add_subplot(111, projection='3d') 103 | ax.quiver(X, Y, Z, U, V, W) 104 | # ax.set_xlim(x_min, x_max) 105 | # ax.set_ylim(y_min, y_max) 106 | # ax.set_zlim(z_min, z_max) 107 | 108 | return ax 109 | 110 | dataset = SceneDataset(False, '/dataset/yokoli/data/pol/mitsuba_bunny', downscale=32, scale_radius = 3, chromatic='sRGB',opengl=True) 111 | 112 | fig = plt.figure(figsize=[19.2,10.8]) 113 | ax = fig.add_subplot(111, projection='3d') 114 | ax.set_xlabel('x') 115 | ax.set_ylabel('y') 116 | ax.set_zlabel('z') 117 | ax.set_xlim(-2, 2) 118 | ax.set_ylim(-2, 2) 119 | ax.set_zlim(-2, 2) 120 | H, W = (dataset.H, dataset.W) 121 | 122 | for i in range(dataset.n_images): 123 | _, model_input, _ = dataset[i] 124 | intrinsics = model_input["intrinsics"][None, ...] 125 | c2w = model_input['c2w'][None, ...] 126 | # c2w = dataset.get_gt_pose(scaled=True) 127 | rays_o, rays_d, select_inds = get_center_ray(c2w, intrinsics, H, W, N_rays=1) 128 | rays_o = rays_o.data.squeeze(0).cpu().numpy() 129 | rays_d = rays_d.data.squeeze(0).cpu().numpy() 130 | # x y z -> x z -y 131 | # rays_o = rays_o[:,[0,2,1]] 132 | # rays_d = rays_d[:,[0,2,1]] 133 | # rays_o[:,2] = -rays_o[:,2] 134 | # rays_d[:,2] = -rays_d[:,2] 135 | ax = plot_rays(rays_o, rays_d, ax) 136 | fig.savefig('rays.png', bbox_inches='tight') -------------------------------------------------------------------------------- /tools/vis_surface_and_cam.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/newdata/yokoli/neurecon') 3 | 4 | from utils import io_util 5 | from dataio import get_data 6 | 7 | import skimage 8 | import skimage.measure 9 | import numpy as np 10 | import open3d as o3d 11 | 12 | 13 | def get_camera_frustum(img_size, K, W2C, frustum_length=0.5, color=[0., 1., 0.]): 14 | W, H = img_size 15 | hfov = np.rad2deg(np.arctan(W / 2. / K[0, 0]) * 2.) 16 | vfov = np.rad2deg(np.arctan(H / 2. / K[1, 1]) * 2.) 
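# Illustrative check, assuming the intrinsics from tools/360cameraPath (fx = fy ~= 3707.82
# with cx = 1223.5, cy = 1023.5, i.e. roughly W = 2448, H = 2048):
#   hfov ~= 2 * atan(1224 / 3707.82) ~= 36.5 deg, vfov ~= 30.9 deg
# With the default frustum_length = 0.5, the half extents computed below come out to
# half_w ~= 0.165 and half_h ~= 0.138.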
17 | half_w = frustum_length * np.tan(np.deg2rad(hfov / 2.)) 18 | half_h = frustum_length * np.tan(np.deg2rad(vfov / 2.)) 19 | 20 | # build view frustum for camera (I, 0) 21 | frustum_points = np.array([[0., 0., 0.], # frustum origin 22 | [-half_w, -half_h, frustum_length], # top-left image corner 23 | [half_w, -half_h, frustum_length], # top-right image corner 24 | [half_w, half_h, frustum_length], # bottom-right image corner 25 | [-half_w, half_h, frustum_length]]) # bottom-left image corner 26 | frustum_lines = np.array([[0, i] for i in range(1, 5)] + [[i, (i+1)] for i in range(1, 4)] + [[4, 1]]) 27 | frustum_colors = np.tile(np.array(color).reshape((1, 3)), (frustum_lines.shape[0], 1)) 28 | 29 | # frustum_colors = np.vstack((np.tile(np.array([[1., 0., 0.]]), (4, 1)), 30 | # np.tile(np.array([[0., 1., 0.]]), (4, 1)))) 31 | 32 | # transform view frustum from (I, 0) to (R, t) 33 | C2W = np.linalg.inv(W2C) 34 | frustum_points = np.dot(np.hstack((frustum_points, np.ones_like(frustum_points[:, 0:1]))), C2W.T) 35 | frustum_points = frustum_points[:, :3] / frustum_points[:, 3:4] 36 | 37 | return frustum_points, frustum_lines, frustum_colors 38 | 39 | def frustums2lineset(frustums): 40 | N = len(frustums) 41 | merged_points = np.zeros((N*5, 3)) # 5 vertices per frustum 42 | merged_lines = np.zeros((N*8, 2)) # 8 lines per frustum 43 | merged_colors = np.zeros((N*8, 3)) # each line gets a color 44 | 45 | for i, (frustum_points, frustum_lines, frustum_colors) in enumerate(frustums): 46 | merged_points[i*5:(i+1)*5, :] = frustum_points 47 | merged_lines[i*8:(i+1)*8, :] = frustum_lines + i*5 48 | merged_colors[i*8:(i+1)*8, :] = frustum_colors 49 | 50 | lineset = o3d.geometry.LineSet() 51 | lineset.points = o3d.utility.Vector3dVector(merged_points) 52 | lineset.lines = o3d.utility.Vector2iVector(merged_lines) 53 | lineset.colors = o3d.utility.Vector3dVector(merged_colors) 54 | 55 | return lineset 56 | 57 | 58 | # ---------------------- 59 | # plot cameras alongside with mesh 60 | # modified from NeRF++. 
https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/extract_sfm.py 61 | def visualize_cameras(colored_camera_dicts, sphere_radius, camera_size=0.1, geometry_file=None, geometry_type='mesh', backface=False): 62 | things_to_draw = [] 63 | 64 | if sphere_radius > 0: 65 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=sphere_radius, resolution=10) 66 | sphere = o3d.geometry.LineSet.create_from_triangle_mesh(sphere) 67 | sphere.paint_uniform_color((1, 0, 0)) 68 | things_to_draw.append(sphere) 69 | 70 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5, origin=[0., 0., 0.]) 71 | things_to_draw.append(coord_frame) 72 | 73 | idx = 0 74 | for camera_dict in colored_camera_dicts: 75 | idx += 1 76 | # K = np.array(camera_dict['K']).reshape((4, 4)) 77 | K = np.array(camera_dict['K']).reshape((3, 3)) 78 | W2C = np.array(camera_dict['W2C']).reshape((4, 4)) 79 | C2W = np.linalg.inv(W2C) 80 | img_size = camera_dict['img_size'] 81 | color = camera_dict['color'] 82 | frustums = [get_camera_frustum(img_size, K, W2C, frustum_length=camera_size, color=color)] 83 | cameras = frustums2lineset(frustums) 84 | things_to_draw.append(cameras) 85 | 86 | if geometry_file is not None: 87 | if geometry_type == 'mesh': 88 | geometry = o3d.io.read_triangle_mesh(geometry_file) 89 | geometry.compute_vertex_normals() 90 | elif geometry_type == 'pointcloud': 91 | geometry = o3d.io.read_point_cloud(geometry_file) 92 | else: 93 | raise Exception('Unknown geometry_type: ', geometry_type) 94 | 95 | things_to_draw.append(geometry) 96 | if backface: 97 | o3d.visualization.RenderOption.mesh_show_back_face = True 98 | o3d.visualization.draw_geometries(things_to_draw) 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = io_util.create_args_parser() 103 | parser.add_argument("--scan_id", type=int, default=40) 104 | parser.add_argument("--mesh_file", type=str, default=None) 105 | parser.add_argument("--sphere_radius", type=float, default=3.0) 106 | parser.add_argument("--backface",action='store_true', help='render show back face') 107 | args = parser.parse_args() 108 | 109 | # load camera 110 | args, unknown = parser.parse_known_args() 111 | config = io_util.load_config(args, unknown) 112 | dataset = get_data(config) 113 | 114 | #------------- 115 | colored_camera_dicts = [] 116 | for i in range(len(dataset)): 117 | (_, model_input, ground_truth) = dataset[i] 118 | c2w = model_input['c2w'].data.cpu().numpy() 119 | intrinsics = model_input["intrinsics"].data.cpu().numpy() 120 | 121 | cam_dict = {} 122 | cam_dict['img_size'] = (dataset.W, dataset.H) 123 | cam_dict['W2C'] = np.linalg.inv(c2w) 124 | cam_dict['K'] = intrinsics 125 | # cam_dict['color'] = [0, 1, 1] 126 | cam_dict['color'] = [1, 0, 0] 127 | 128 | # if i == 0: 129 | # cam_dict['color'] = [1, 0, 0] 130 | 131 | # if i == 1: 132 | # cam_dict['color'] = [0, 1, 0] 133 | 134 | # if i == 28: 135 | # cam_dict['color'] = [1, 0, 0] 136 | 137 | colored_camera_dicts.append(cam_dict) 138 | 139 | visualize_cameras(colored_camera_dicts, args.sphere_radius, geometry_file=args.mesh_file, backface=args.backface) 140 | 141 | 142 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yukiumi13/GNeRP/b38778eead5a11efa97a5495d503a5b959ae7751/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-310.pyc: 
--------------------------------------------------------------------------------
/utils/checkpoints.py:
--------------------------------------------------------------------------------
from utils.print_fn import log

import os
import urllib.parse  # NOTE: import the submodule explicitly; `import urllib` alone does not expose urllib.parse
import torch
from torch.utils import model_zoo

# torch.autograd.set_detect_anomaly(True)


class CheckpointIO(object):
    '''CheckpointIO class.

    Modified from https://github.com/LMescheder/GAN_stability/blob/master/gan_training/checkpoints.py

    It handles saving and loading checkpoints.

    Args:
        checkpoint_dir (str): path where checkpoints are saved
    '''

    def __init__(self, checkpoint_dir='./chkpts', allow_mkdir=True, **kwargs):
        self.module_dict = kwargs
        self.checkpoint_dir = checkpoint_dir
        if allow_mkdir:
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)

    def register_modules(self, **kwargs):
        '''Registers modules in the current module dictionary.'''
        self.module_dict.update(kwargs)

    def save(self, filename, **kwargs):
        '''Saves the current module dictionary.

        Args:
            filename (str): name of output file
        '''
        if not os.path.isabs(filename):
            filename = os.path.join(self.checkpoint_dir, filename)
        log.info("=> Saving ckpt to {}".format(filename))
        outdict = kwargs
        for k, v in self.module_dict.items():
            outdict[k] = v.state_dict()
        torch.save(outdict, filename)
        log.info("Done.")

    def load(self, filename):
        '''Loads a module dictionary from a local file or a url.

        Args:
            filename (str): name of saved module dictionary
        '''
        if is_url(filename):
            return self.load_url(filename)
        else:
            return self.load_file(filename)

    def load_file(self, filepath, no_reload=False, ignore_keys=None, only_use_keys=None, map_location='cuda'):
        '''Loads a module dictionary from file.

        Args:
            filepath (str): file path of saved module dictionary
        '''
        # NOTE: use None instead of a mutable [] default argument.
        if ignore_keys is None:
            ignore_keys = []

        assert not ((len(ignore_keys) > 0) and only_use_keys is not None), \
            'please specify at most one of [ckpt_ignore_keys, ckpt_only_use_keys]'

        if filepath is not None and filepath != "None":
            ckpts = [filepath]
        else:
            ckpts = sorted_ckpts(self.checkpoint_dir)

        # NOTE: the conditional expression must be parenthesized; `+` binds
        # tighter, so without parentheses the prefix was dropped for >= 5 ckpts.
        log.info("=> Found ckpts: " +
                 ("{}".format(ckpts) if len(ckpts) < 5 else "...,{}".format(ckpts[-5:])))

        if len(ckpts) > 0 and not no_reload:
            ckpt_file = ckpts[-1]
            log.info('=> Loading checkpoint from local file: ' + str(ckpt_file))
            state_dict = torch.load(ckpt_file, map_location=map_location)

            if len(ignore_keys) > 0:
                to_load_state_dict = {}
                for k in state_dict.keys():
                    if k in ignore_keys:
                        log.info("=> [ckpt] Ignoring keys: {}".format(k))
                    else:
                        to_load_state_dict[k] = state_dict[k]
            elif only_use_keys is not None:
                if not isinstance(only_use_keys, list):
                    only_use_keys = [only_use_keys]
                to_load_state_dict = {}
                for k in only_use_keys:
                    log.info("=> [ckpt] Only use keys: {}".format(k))
                    to_load_state_dict[k] = state_dict[k]
            else:
                to_load_state_dict = state_dict

            scalars = self.parse_state_dict(to_load_state_dict, ignore_keys)
            return scalars
        else:
            return {}

    def load_url(self, url):
        '''Loads a module dictionary from a url.

        Args:
            url (str): url to saved model
        '''
        log.info(url)
        log.info('=> Loading checkpoint from url...')
        state_dict = model_zoo.load_url(url, progress=True)
        scalars = self.parse_state_dict(state_dict)
        return scalars

    def parse_state_dict(self, state_dict, ignore_keys=()):
        '''Parses the state_dict of the model and returns the scalar entries.

        Args:
            state_dict (dict): state dict of model

        NOTE: ignore_keys defaults to () so that load_url, which passes no
        second argument, does not raise a TypeError.
        '''
        for k, v in self.module_dict.items():
            if k in state_dict:
                v.load_state_dict(state_dict[k])
            else:
                if k not in ignore_keys:
                    log.info('Warning: could not find %s in checkpoint!' % k)
        scalars = {k: v for k, v in state_dict.items()
                   if k not in self.module_dict}
        return scalars


def is_url(url):
    scheme = urllib.parse.urlparse(url).scheme
    return scheme in ('http', 'https')


def sorted_ckpts(checkpoint_dir):
    '''Returns *.pt files sorted by name, with 'latest' and 'final' moved to the end.'''
    ckpts = []
    if os.path.exists(checkpoint_dir):
        latest = None
        final = None
        for fname in sorted(os.listdir(checkpoint_dir)):
            fpath = os.path.join(checkpoint_dir, fname)
            if ".pt" in fname:
                ckpts.append(fpath)
                if 'latest' in fname:
                    latest = fpath
                elif 'final' in fname:
                    final = fpath
        if latest is not None:
            ckpts.remove(latest)
            ckpts.append(latest)
        if final is not None:
            ckpts.remove(final)
            ckpts.append(final)
    return ckpts
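A minimal usage sketch for CheckpointIO (illustrative, not part of the repository; the model/optimizer registration mirrors how val.py uses the class, and map_location='cpu' is only for this toy run):

import torch
from utils.checkpoints import CheckpointIO

model = torch.nn.Linear(3, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

ckpt_io = CheckpointIO(checkpoint_dir='./chkpts', model=model, optimizer=optimizer)
ckpt_io.save('latest.pt', global_step=1000)            # extra kwargs are stored alongside the state dicts
scalars = ckpt_io.load_file(None, map_location='cpu')  # filepath=None -> newest ckpt in checkpoint_dir
print(scalars['global_step'])                          # non-module entries come back as scalars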
--------------------------------------------------------------------------------
/utils/dist_util.py:
--------------------------------------------------------------------------------
import os
import torch
import random
import numpy as np
from typing import Optional

import torch.distributed as dist

rank = 0        # process id, for IPC
local_rank = 0  # local GPU device id
world_size = 1  # number of processes


def init_env(args):
    global rank, local_rank, world_size
    if args.ddp:
        # ------------- multi-process running, using DDP
        if 'SLURM_PROCID' in os.environ:
            # --------- for SLURM
            slurm_initialize('nccl', port=args.port)
        else:
            # --------- for torch.distributed.launch / torchrun
            dist.init_process_group(backend='nccl')

        rank = int(os.environ['RANK'])
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        torch.cuda.set_device(local_rank)
        args.device_ids = [local_rank]
        print("=> Init Env @ DDP: rank={}, world_size={}, local_rank={}.\n\tdevice_ids set to {}".format(
            rank, world_size, local_rank, args.device_ids))
        # NOTE: important!
    else:
        # ------------- single-process running, using a single GPU or DataParallel
        print("=> Init Env @ single process: use device_ids = {}".format(args.device_ids))
        rank = 0
        local_rank = args.device_ids[0]
        world_size = 1
        torch.cuda.set_device(args.device_ids[0])
    set_seed(42)


def slurm_initialize(backend='nccl', port: Optional[int] = None):
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    # NOTE: the master-address parsing below is site-specific (it assumes an
    # 8-character hostname prefix followed by a dash-separated IP); adapt it
    # to your cluster's naming scheme.
    if '[' in node_list:
        beg = node_list.find('[')
        pos1 = node_list.find('-', beg)
        if pos1 < 0:
            pos1 = 1000
        pos2 = node_list.find(',', beg)
        if pos2 < 0:
            pos2 = 1000
        node_list = node_list[:min(pos1, pos2)].replace('[', '')
    addr = node_list[8:].replace('-', '.')
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' not in os.environ:
        os.environ["MASTER_PORT"] = "13333"
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    if backend == 'nccl':
        dist.init_process_group(backend='nccl')
    else:
        dist.init_process_group(backend='gloo', rank=proc_id, world_size=ntasks)
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    torch.cuda.set_device(device)
    os.environ['LOCAL_RANK'] = str(device)


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def is_master():
    return rank == 0


def get_rank():
    return int(os.environ.get('SLURM_PROCID', rank))


def get_local_rank():
    return int(os.environ.get('LOCAL_RANK', local_rank))


def get_world_size():
    return int(os.environ.get('SLURM_NTASKS', world_size))
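A single-process initialization sketch (the args fields mirror the parser in val.py; argparse.Namespace stands in for the config object, and a CUDA device is assumed by init_env):

import argparse
import torch
from utils import dist_util

args = argparse.Namespace(ddp=False, port=None,
                          device_ids=list(range(torch.cuda.device_count())))
dist_util.init_env(args)
print(dist_util.get_rank(), dist_util.get_local_rank(), dist_util.get_world_size())

# multi-GPU runs go through the same entry point, driven by a launcher that
# sets RANK / LOCAL_RANK / WORLD_SIZE; e.g. (script name and flag assumed to
# mirror val.py's parser):
#   torchrun --nproc_per_node=4 val.py --ddp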
--------------------------------------------------------------------------------
/utils/general.py:
--------------------------------------------------------------------------------
import os
from glob import glob
import torch


def mkdir_ifnotexists(directory):
    # NOTE: creates a single level only; use os.makedirs for nested paths
    if not os.path.exists(directory):
        os.mkdir(directory)


def get_class(kls):
    parts = kls.split('.')
    module = ".".join(parts[:-1])
    m = __import__(module)
    for comp in parts[1:]:
        m = getattr(m, comp)
    return m


def glob_imgs(path):
    imgs = []
    for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG', '*.bmp']:
        imgs.extend(glob(os.path.join(path, ext)))
    return imgs


def split_input(model_input, total_pixels, n_pixels=10000):
    '''
    Split the input into chunks to fit CUDA memory at large resolutions.
    Decrease n_pixels in case of a CUDA out-of-memory error.
    '''
    split = []
    for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)):
        data = model_input.copy()
        data['uv'] = torch.index_select(model_input['uv'], 1, indx)
        if 'object_mask' in data:
            data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
        split.append(data)
    return split


def merge_output(res, total_pixels, batch_size):
    '''Merge the split outputs back into full-image tensors.'''
    model_outputs = {}
    for entry in res[0]:
        if res[0][entry] is None:
            continue
        if len(res[0][entry].shape) == 1:
            model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],
                                             1).reshape(batch_size * total_pixels)
        else:
            model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
                                             1).reshape(batch_size * total_pixels, -1)

    return model_outputs


def concat_home_dir(path):
    return os.path.join(os.environ['HOME'], 'data', path)

--------------------------------------------------------------------------------
/utils/log_utils.py:
--------------------------------------------------------------------------------
import prettytable as pt


def pretty_table_log(field, values):
    tb = pt.PrettyTable()
    tb.field_names = field
    tb.add_row(values)
    print(tb)
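A chunked-evaluation sketch for split_input/merge_output (the toy run_model below is an assumption; split_input requires CUDA because it builds its index tensor on the GPU):

import torch
from utils.general import split_input, merge_output

B, total_pixels = 1, 25000
model_input = {'uv': torch.rand(B, total_pixels, 2).cuda()}

def run_model(chunk):
    # toy per-pixel output of shape (B, n, 3)
    return {'rgb': chunk['uv'][..., :1].expand(-1, -1, 3).contiguous()}

res = [run_model(s) for s in split_input(model_input, total_pixels, n_pixels=10000)]
out = merge_output(res, total_pixels, batch_size=B)
print(out['rgb'].shape)  # torch.Size([25000, 3])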
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
from utils import io_util
from utils.print_fn import log

import os
import torch
import pickle
import imageio
import torchvision
import numpy as np

import torch.distributed as dist

# ---------------------------------------------------------------------------
# ---------------------- tensorboard / image recorder ----------------------
# ---------------------------------------------------------------------------


class Logger(object):
    """
    Modified from https://github.com/LMescheder/GAN_stability/blob/master/gan_training/logger.py
    """
    def __init__(self,
                 log_dir,
                 img_dir,
                 monitoring=None,
                 monitoring_dir=None,
                 rank=0,
                 is_master=True,
                 multi_process_logging=True):
        self.stats = dict()
        self.log_dir = log_dir
        self.img_dir = img_dir
        self.rank = rank
        self.is_master = is_master
        self.multi_process_logging = multi_process_logging

        if self.is_master:
            io_util.cond_mkdir(self.log_dir)
            io_util.cond_mkdir(self.img_dir)
        if self.multi_process_logging:
            dist.barrier()

        self.monitoring = None
        self.monitoring_dir = None

        # NOTE: for now, we allow tensorboard writing on all child processes,
        # as it's already nicely supported, and the event files of different
        # processes are automatically aggregated when visualizing.
        # https://discuss.pytorch.org/t/using-tensorboard-with-distributeddataparallel/102555/7
        if not (monitoring is None or monitoring == 'none'):
            self.setup_monitoring(monitoring, monitoring_dir)

    def setup_monitoring(self, monitoring, monitoring_dir):
        self.monitoring = monitoring
        self.monitoring_dir = monitoring_dir
        if monitoring == 'tensorboard':
            # NOTE: bundled with torch since 1.2
            from torch.utils.tensorboard import SummaryWriter
            self.tb = SummaryWriter(self.monitoring_dir)
        else:
            raise NotImplementedError('Monitoring tool "%s" not supported!'
                                      % monitoring)

    def add(self, category, k, v, it):
        if category not in self.stats:
            self.stats[category] = {}

        if k not in self.stats[category]:
            self.stats[category][k] = []

        self.stats[category][k].append((it, v))

        k_name = '%s/%s' % (category, k)
        if self.monitoring == 'telemetry':
            # NOTE: legacy branch; setup_monitoring never enables 'telemetry',
            # and self.tm is never set.
            self.tm.metric_push_async({
                'metric': k_name, 'value': v, 'it': it
            })
        elif self.monitoring == 'tensorboard':
            self.tb.add_scalar(k_name, v, it)

    def add_vector(self, category, k, vec, it):
        if category not in self.stats:
            self.stats[category] = {}

        if k not in self.stats[category]:
            self.stats[category][k] = []

        if isinstance(vec, torch.Tensor):
            vec = vec.data.clone().cpu().numpy()

        self.stats[category][k].append((it, vec))

    def add_imgs(self, imgs, class_name, it):
        outdir = os.path.join(self.img_dir, class_name)
        if self.is_master and not os.path.exists(outdir):
            os.makedirs(outdir)
        if self.multi_process_logging:
            dist.barrier()
        outfile = os.path.join(outdir, '{:08d}_{}.png'.format(it, self.rank))

        # imgs = imgs / 2 + 0.5
        imgs = torchvision.utils.make_grid(imgs)
        torchvision.utils.save_image(imgs.clone(), outfile, nrow=4)

        if self.monitoring == 'tensorboard':
            self.tb.add_image(class_name, imgs, global_step=it)

    def add_figure(self, fig, class_name, it, save_img=True):
        if save_img:
            outdir = os.path.join(self.img_dir, class_name)
            if self.is_master and not os.path.exists(outdir):
                os.makedirs(outdir)
            if self.multi_process_logging:
                dist.barrier()
            outfile = os.path.join(outdir, '{:08d}_{}.png'.format(it, self.rank))

            image_hwc = io_util.figure_to_image(fig)
            imageio.imwrite(outfile, image_hwc)
            if self.monitoring == 'tensorboard':
                if len(image_hwc.shape) == 3:
                    image_hwc = np.array(image_hwc[None, ...])
                self.tb.add_images(class_name, torch.from_numpy(image_hwc), dataformats='NHWC', global_step=it)
        else:
            if self.monitoring == 'tensorboard':
                self.tb.add_figure(class_name, fig, it)

    def add_module_param(self, module_name, module, it):
        if self.monitoring == 'tensorboard':
            for name, param in module.named_parameters():
                self.tb.add_histogram("{}/{}".format(module_name, name), param.detach(), it)

    def get_last(self, category, k, default=0.):
        if category not in self.stats:
            return default
        elif k not in self.stats[category]:
            return default
        else:
            return self.stats[category][k][-1][1]

    def save_stats(self, filename):
        filename = os.path.join(self.log_dir, filename + '_{}'.format(self.rank))
        with open(filename, 'wb') as f:
            pickle.dump(self.stats, f)

    def load_stats(self, filename):
        filename = os.path.join(self.log_dir, filename + '_{}'.format(self.rank))
        if not os.path.exists(filename):
            # log.info('=> File "%s" does not exist, will create new after calling save_stats()' % filename)
            return

        try:
            with open(filename, 'rb') as f:
                self.stats = pickle.load(f)
            log.info("=> Load file: {}".format(filename))
        except EOFError:
            log.info('Warning: log file corrupted!')
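A standalone usage sketch for Logger (single process; the paths below are illustrative):

import torch
from utils.logger import Logger

logger = Logger(log_dir='./logs', img_dir='./logs/imgs',
                monitoring='tensorboard', monitoring_dir='./logs/events',
                rank=0, is_master=True, multi_process_logging=False)
for it in range(100):
    logger.add('loss', 'total', 1.0 / (it + 1), it)  # shows up in tensorboard as 'loss/total'
logger.add_imgs(torch.rand(4, 3, 64, 64), 'val/samples', 0)
logger.save_stats('stats.p')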
--------------------------------------------------------------------------------
/utils/mesh_util.py:
--------------------------------------------------------------------------------
from utils.print_fn import log

import time
import plyfile
import skimage
import skimage.measure
import numpy as np
from tqdm import tqdm

import torch


def convert_sigma_samples_to_ply(
        input_3d_sigma_array: np.ndarray,
        voxel_grid_origin,
        volume_size,
        ply_filename_out,
        level=5.0,
        offset=None,
        scale=None,):
    """
    Convert density/SDF samples to .ply.

    :param input_3d_sigma_array: a float array of shape (n, n, n)
    :param voxel_grid_origin: a list of three floats: the bottom-left-down origin of the voxel grid
    :param volume_size: a list of three floats (voxel spacing along each axis)
    :param ply_filename_out: string, path of the filename to save to

    This function is adapted from: https://github.com/RobotLocomotion/spartan
    """
    start_time = time.time()

    verts, faces, normals, values = skimage.measure.marching_cubes(
        input_3d_sigma_array, level=level, spacing=volume_size
    )

    # transform from voxel coordinates to camera coordinates
    # note x and y are flipped in the output of marching_cubes
    mesh_points = np.zeros_like(verts)
    mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
    mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
    mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]

    # apply additional offset and scale
    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset

    # write to the ply file
    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    for i in range(0, num_verts):
        verts_tuple[i] = tuple(mesh_points[i, :])

    faces_building = []
    for i in range(0, num_faces):
        faces_building.append(((faces[i, :].tolist(),)))
    faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])

    el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
    el_faces = plyfile.PlyElement.describe(faces_tuple, "face")

    ply_data = plyfile.PlyData([el_verts, el_faces])
    log.info("saving mesh to %s" % str(ply_filename_out))
    ply_data.write(ply_filename_out)

    log.info(
        "converting to ply format and writing to file took {} s".format(
            time.time() - start_time
        )
    )
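A toy check for convert_sigma_samples_to_ply, writing the SDF of a radius-0.5 sphere to a mesh (level=0.0 picks the zero isosurface; the output path is illustrative):

import numpy as np
from utils.mesh_util import convert_sigma_samples_to_ply

N, s = 64, 2.0
xs = np.linspace(-s / 2, s / 2, N)
grid = np.stack(np.meshgrid(xs, xs, xs, indexing='ij'), axis=-1)
sdf = np.linalg.norm(grid, axis=-1) - 0.5        # signed distance to the sphere
convert_sigma_samples_to_ply(sdf, [-s / 2] * 3, [s / N] * 3, './sphere.ply', level=0.0)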
def extract_mesh(implicit_surface, volume_size=2.0, level=0.0, N=512, filepath='./surface.ply', show_progress=True, chunk=16*1024):
    s = volume_size
    voxel_grid_origin = [-s/2., -s/2., -s/2.]
    volume_size = [s, s, s]

    overall_index = np.arange(0, N ** 3, 1).astype(np.int_)
    xyz = np.zeros([N ** 3, 3])

    # transform the flat index into x, y, z indices
    xyz[:, 2] = overall_index % N
    xyz[:, 1] = (overall_index / N) % N
    xyz[:, 0] = ((overall_index / N) / N) % N

    # transform the indices into x, y, z coordinates
    # (the [2]/[0] origin indices are swapped, which is harmless here since
    # the origin is symmetric in all three axes)
    xyz[:, 0] = (xyz[:, 0] * (s/(N-1))) + voxel_grid_origin[2]
    xyz[:, 1] = (xyz[:, 1] * (s/(N-1))) + voxel_grid_origin[1]
    xyz[:, 2] = (xyz[:, 2] * (s/(N-1))) + voxel_grid_origin[0]

    def batchify(query_fn, inputs: np.ndarray, chunk=chunk):
        out = []
        for i in tqdm(range(0, inputs.shape[0], chunk), disable=not show_progress):
            out_i = query_fn(torch.from_numpy(inputs[i:i+chunk]).float().cuda()).data.cpu().numpy()
            out.append(out_i)
        out = np.concatenate(out, axis=0)
        return out

    out = batchify(implicit_surface.forward, xyz)
    out = out.reshape([N, N, N])
    convert_sigma_samples_to_ply(out, voxel_grid_origin, [float(v) / N for v in volume_size], filepath, level=level)


def extract_mesh_rgb(implicit_surface, radiance_net, volume_size=2.0, level=0.0, N=512, filepath='./surface.ply', show_progress=True, chunk=16*1024):
    # NOTE: radiance_net is currently unused; this function duplicates
    # extract_mesh and is kept for interface compatibility.
    s = volume_size
    voxel_grid_origin = [-s/2., -s/2., -s/2.]
    volume_size = [s, s, s]

    # NOTE: np.int was removed in NumPy 1.20+; use np.int_ instead.
    overall_index = np.arange(0, N ** 3, 1).astype(np.int_)
    xyz = np.zeros([N ** 3, 3])

    # transform the flat index into x, y, z indices
    xyz[:, 2] = overall_index % N
    xyz[:, 1] = (overall_index / N) % N
    xyz[:, 0] = ((overall_index / N) / N) % N

    # transform the indices into x, y, z coordinates
    xyz[:, 0] = (xyz[:, 0] * (s/(N-1))) + voxel_grid_origin[2]
    xyz[:, 1] = (xyz[:, 1] * (s/(N-1))) + voxel_grid_origin[1]
    xyz[:, 2] = (xyz[:, 2] * (s/(N-1))) + voxel_grid_origin[0]

    def batchify(query_fn, inputs: np.ndarray, chunk=chunk):
        out = []
        for i in tqdm(range(0, inputs.shape[0], chunk), disable=not show_progress):
            out_i = query_fn(torch.from_numpy(inputs[i:i+chunk]).float().cuda()).data.cpu().numpy()
            out.append(out_i)
        out = np.concatenate(out, axis=0)
        return out

    out = batchify(implicit_surface.forward, xyz)
    out = out.reshape([N, N, N])
    convert_sigma_samples_to_ply(out, voxel_grid_origin, [float(v) / N for v in volume_size], filepath, level=level)
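A usage sketch for extract_mesh; the toy SphereSDF module below is an assumption standing in for the trained implicit surface (extract_mesh only needs a module mapping (M, 3) CUDA points to (M,) values):

import torch
from utils.mesh_util import extract_mesh

class SphereSDF(torch.nn.Module):
    def forward(self, x):          # x: (M, 3) CUDA tensor -> (M,) signed distances
        return x.norm(dim=-1) - 0.5

extract_mesh(SphereSDF().cuda(), volume_size=2.0, level=0.0, N=128, filepath='./sphere.ply')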
--------------------------------------------------------------------------------
/utils/print_fn.py:
--------------------------------------------------------------------------------
# NOTE: this file is separated out to prevent circular imports

import logging
import sys

# ---------------------------------------------------------------------------
# ----------------------- logging instead of printing ----------------------
# ---------------------------------------------------------------------------

logs = set()

# LOGGER colors
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[1;%dm"

COLORS = {
    'WARNING': YELLOW,
    'INFO': WHITE,
    'DEBUG': BLUE,
    'CRITICAL': YELLOW,
    'ERROR': RED
}


class ColoredFormatter(logging.Formatter):
    def __init__(self, msg, use_color=True):
        logging.Formatter.__init__(self, msg)
        self.use_color = use_color

    def format(self, record):
        msg = record.msg
        levelname = record.levelname
        if self.use_color and levelname in COLORS and COLORS[levelname] != WHITE:
            if isinstance(msg, str):
                msg_color = COLOR_SEQ % (30 + COLORS[levelname]) + msg + RESET_SEQ
                record.msg = msg_color
            levelname_color = COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ
            record.levelname = levelname_color
        return logging.Formatter.format(self, record)


def init_log(name, level=logging.INFO):
    # NOTE: return the existing logger (instead of None) if already initialized.
    if (name, level) in logs:
        return logging.getLogger(name)

    from utils.dist_util import is_master, get_rank

    logs.add((name, level))
    logger = logging.getLogger(name)
    logger.setLevel(level)
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(level)

    # only the master process emits records
    logger.addFilter(lambda record: is_master())

    format_str = f'%(asctime)s-rk{get_rank()}-%(filename)s#%(lineno)d:%(message)s'
    formatter = ColoredFormatter(format_str)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    logger.propagate = False

    return logger


log = init_log('global', logging.INFO)
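The module-level `log` is the intended entry point; records are colored per level and emitted only on the master process:

from utils.print_fn import log

log.info("plain message")      # INFO renders in the default color
log.warning("check this")      # WARNING renders in yellow
log.error("something failed")  # ERROR renders in red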
--------------------------------------------------------------------------------
/utils/train_util.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from typing import Iterable


def calc_grad_norm(norm_type=2.0, **named_models):
    gradient_norms = {'total': 0.0}
    for name, model in named_models.items():
        gradient_norms[name] = 0.0
        for p in list(model.parameters()):
            if p.requires_grad and p.grad is not None:
                param_norm = p.grad.data.norm(norm_type)
                gradient_norms[name] += param_norm.item() ** norm_type
        gradient_norms['total'] += gradient_norms[name]
    for k, v in gradient_norms.items():
        gradient_norms[k] = v ** (1.0 / norm_type)
    return gradient_norms


def count_trainable_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])


def batchify_query(query_fn, *args: Iterable[torch.Tensor], chunk, dim_batchify):
    '''Slice the inputs along the batchify dimension and gather the outputs.'''
    # [(B), N_rays, N_pts, ...] -> [(B), N_rays*N_pts, ...]
    _N_rays = args[0].shape[dim_batchify]
    _N_pts = args[0].shape[dim_batchify + 1]
    args = [arg.flatten(dim_batchify, dim_batchify + 1) for arg in args]
    _N = args[0].shape[dim_batchify]
    raw_ret = []
    for i in range(0, _N, chunk):
        if dim_batchify == 0:
            args_i = [arg[i:i+chunk] for arg in args]
        elif dim_batchify == 1:
            args_i = [arg[:, i:i+chunk] for arg in args]
        elif dim_batchify == 2:
            args_i = [arg[:, :, i:i+chunk] for arg in args]
        else:
            raise NotImplementedError
        raw_ret_i = query_fn(*args_i)
        if not isinstance(raw_ret_i, tuple):
            raw_ret_i = [raw_ret_i]
        raw_ret.append(raw_ret_i)
    collate_raw_ret = []
    num_entry = 0
    for entry in zip(*raw_ret):
        if isinstance(entry[0], dict):
            tmp_dict = {}
            for list_item in entry:
                for k, v in list_item.items():
                    if k not in tmp_dict:
                        tmp_dict[k] = []
                    tmp_dict[k].append(v)
            for k in tmp_dict.keys():
                # [(B), N_rays*N_pts, ...] -> [(B), N_rays, N_pts, ...]
                # tmp_dict[k] = torch.cat(tmp_dict[k], dim=dim_batchify).unflatten(dim_batchify, [_N_rays, _N_pts])
                # NOTE: reshape keeps compatibility with torch 1.6 (no Tensor.unflatten)
                v = torch.cat(tmp_dict[k], dim=dim_batchify)
                tmp_dict[k] = v.reshape([*v.shape[:dim_batchify], _N_rays, _N_pts, *v.shape[dim_batchify+1:]])
            entry = tmp_dict
        else:
            # [(B), N_rays*N_pts, ...] -> [(B), N_rays, N_pts, ...]
            # entry = torch.cat(entry, dim=dim_batchify).unflatten(dim_batchify, [_N_rays, _N_pts])
            # NOTE: reshape keeps compatibility with torch 1.6 (no Tensor.unflatten)
            v = torch.cat(entry, dim=dim_batchify)
            entry = v.reshape([*v.shape[:dim_batchify], _N_rays, _N_pts, *v.shape[dim_batchify+1:]])
        collate_raw_ret.append(entry)
        num_entry += 1
    if num_entry == 1:
        return collate_raw_ret[0]
    else:
        return tuple(collate_raw_ret)
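A usage sketch for batchify_query following the [B, N_rays, N_pts, ...] shape convention noted in the code (the toy sdf_fn is an assumption):

import torch
from utils.train_util import batchify_query

def sdf_fn(x):                       # x: [B, n, 3] -> [B, n]
    return x.norm(dim=-1) - 0.5

pts = torch.rand(1, 100, 64, 3)      # [B, N_rays, N_pts, 3]
vals = batchify_query(sdf_fn, pts, chunk=4096, dim_batchify=1)
print(vals.shape)                    # torch.Size([1, 100, 64])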
--------------------------------------------------------------------------------
/val.py:
--------------------------------------------------------------------------------
from models.frameworks import get_model
from models.base import get_optimizer, get_scheduler
from utils import rend_util, train_util, mesh_util, io_util
from utils.dist_util import get_local_rank, init_env, is_master, get_rank, get_world_size
from utils.print_fn import log
from utils.logger import Logger
from utils.checkpoints import CheckpointIO
from dataio import get_data

import os
import sys
import time
import functools
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch import autograd

import numpy as np
import json


def validate_all_normals(self):
    # NOTE: written as a method (it expects a trainer-like `self` providing
    # `dataset`, `base_exp_dir` and `validate_image`); it is not called
    # anywhere in this script.
    total_MAE = 0
    idxs = [i for i in range(self.dataset.n_images)]
    f = open(os.path.join(self.base_exp_dir, 'result_normal.txt'), 'a')
    for idx in idxs:
        normal_maps, color_fine = self.validate_image(idx, resolution_level=1, only_normals=True)
        try:
            GT_normal = torch.from_numpy(self.dataset.normal_np[idx])
            cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
            cos_loss = cos(normal_maps.view(-1, 3), GT_normal.view(-1, 3))
            cos_loss = torch.clamp(cos_loss, (-1.0 + 1e-10), (1.0 - 1e-10))
            loss_rad = torch.acos(cos_loss)
            loss_deg = loss_rad * (180.0 / np.pi)
            total_MAE += loss_deg.mean()
            f.write(str(idx) + '_MAE:')
            f.write(str(loss_deg.mean().data.item()) + ' ')
            f.write('\n')
            f.flush()
        except Exception:
            continue
    MAE = total_MAE / self.dataset.n_images
    f.write('\n')
    f.write('MAE_final:')
    f.write(str(MAE.data.item()) + ' ')
    f.close()


def main_function(args):
    args.device_ids = list(range(torch.cuda.device_count()))
    init_env(args)
    # ----------------------------
    # -------- shortcuts ---------
    rank = get_rank()
    local_rank = get_local_rank()
    world_size = get_world_size()
    exp_dir = args.training.exp_dir
    mesh_dir = os.path.join(exp_dir, 'meshes')

    device = torch.device('cuda', local_rank)

    logger = Logger(
        log_dir=exp_dir,
        img_dir=os.path.join(exp_dir, 'imgs'),
        monitoring=args.training.get('monitoring', 'tensorboard'),
        monitoring_dir=os.path.join(exp_dir, 'events'),
        rank=rank, is_master=is_master(), multi_process_logging=(world_size > 1))

    log.info("=> Experiments dir: {}".format(exp_dir))

    val_dataset = get_data(args, downscale=1.0)
    bs = args.data.get('batch_size', None)
    if args.ddp:
        val_sampler = DistributedSampler(val_dataset)
        valloader = torch.utils.data.DataLoader(val_dataset, sampler=val_sampler, batch_size=bs)
    else:
        valloader = DataLoader(val_dataset,
                               batch_size=1,
                               shuffle=True)

    # Create model
    model, trainer, render_kwargs_train, render_kwargs_test, volume_render_fn = get_model(args)
    model.to(device)
    log.info(model)
    log.info("=> Nerf params: " + str(train_util.count_trainable_parameters(model)))

    render_kwargs_test['H'] = val_dataset.H
    render_kwargs_test['W'] = val_dataset.W

    # build optimizer
    optimizer = get_optimizer(args, model)

    # checkpoints
    checkpoint_io = CheckpointIO(checkpoint_dir=os.path.join(exp_dir, 'ckpts'), allow_mkdir=is_master())
    if world_size > 1:
        dist.barrier()
    # Register modules to checkpoint
    checkpoint_io.register_modules(
        model=model,
        optimizer=optimizer,
    )

    # Load checkpoints
    load_dict = checkpoint_io.load_file(
        args.training.ckpt_file,
        ignore_keys=args.training.ckpt_ignore_keys,
        only_use_keys=args.training.ckpt_only_use_keys,
        map_location=device)

    it = load_dict.get('global_step', 0)

    # pretrain if needed; must happen after load_state_dict, since the
    # 'is_pretrained' variable needs to be loaded first.
    # ---------------------------------------------
    # -------- init preparation only done in master
    # ---------------------------------------------

    # Parallel training
    if args.ddp:
        trainer = DDP(trainer, device_ids=args.device_ids, output_device=local_rank, find_unused_parameters=False)
    log.info('=> Start Validating..., it={}, in {}'.format(it, exp_dir))

    total_MAE = 0
    os.makedirs(os.path.join(exp_dir, 'imgs/eval'), exist_ok=True)
    f = open(os.path.join(exp_dir, 'imgs/eval/MAE.txt'), 'w+')
    cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
    # NOTE: iterate the loader directly; the original `next(iter(valloader))`
    # inside an index loop re-created the iterator at every step and, with
    # shuffle=True, sampled views with replacement instead of visiting each once.
    for (val_ind, val_in, val_gt) in tqdm(valloader):
        # -------------------
        # validate
        # -------------------
        with torch.no_grad():
            val_ind = val_ind.item()
            intrinsics = val_in["intrinsics"].to(device)
            c2w = val_in['c2w'].to(device)

            # N_rays=-1 renders the full image
            rays_o, rays_d, select_inds = rend_util.get_rays(
                c2w, intrinsics, render_kwargs_test['H'], render_kwargs_test['W'], N_rays=-1, opengl=args.data.opengl)
            target_rgb = val_gt['rgb'].to(device)
            render_kwargs_test['cone_angle'] = intrinsics[0, 0, 0].item()
            rgb, depth_v, ret = volume_render_fn(rays_o, rays_d, c2w=c2w.expand(*rays_d.shape[:-1], 4, 4), calc_normal=True, detailed_output=True, **render_kwargs_test)

            to_img = functools.partial(
                rend_util.lin2img,
                H=render_kwargs_test['H'], W=render_kwargs_test['W'],
                batched=render_kwargs_test['batched'])

            # logger.add_imgs(to_img(target_rgb), 'val/gt_rgb', val_ind)
            # logger.add_imgs(to_img(rgb), 'val/predicted_rgb', val_ind)
            # logger.add_imgs(to_img((rgb-target_rgb).abs()), 'val/rgb_error_map', val_ind)
            # logger.add_imgs(to_img((depth_v/(depth_v.max()+1e-10)).unsqueeze(-1)), 'val/pred_depth_volume', val_ind)
            # logger.add_imgs(to_img(ret['mask_volume'].unsqueeze(-1)), 'val/pred_mask_volume', it)
            # if 'depth_surface' in ret:
            #     logger.add_imgs(to_img((ret['depth_surface']/ret['depth_surface'].max()).unsqueeze(-1)), 'val/pred_depth_surface', val_ind)
            # if 'mask_surface' in ret:
            #     logger.add_imgs(to_img(ret['mask_surface'].unsqueeze(-1).float()), 'val/predicted_mask', val_ind)
            # if hasattr(trainer, 'val'):
            #     trainer.val(logger, ret, to_img, it, render_kwargs_test)

            # ADD: Polarization Validation
            # if hasattr(trainer, 'val_pol') and render_kwargs_test['has_pol']:
            #     from models.frameworks.pnr import indexing_2d_samples
            #     AoP_map = val_gt['AoP_map'].to(device)
            #     DoP_map = val_gt['DoP_map'].to(device)
            #     aop_sample_idx = indexing_2d_samples(select_inds, render_kwargs_test['H'], render_kwargs_test['W'],
            #                                          args.model.get('gaussian_scale_factor', 1.0)).reshape(1, -1)
            #     aop_samples = torch.gather(val_gt['AoP_map'].to(device), 1, aop_sample_idx.long()).reshape(*AoP_map.shape, 4)
            #     mask = val_gt['mask'].to(device)
            #     gt = {}
            #     gt['AoP_map'] = AoP_map
            #     gt['DoP_map'] = DoP_map
            #     gt['rgb'] = target_rgb
            #     gt['mask'] = mask
            #     gt['aop_samples'] = aop_samples
            #     trainer.val_pol(logger, ret, c2w, gt, to_img, val_ind, render_kwargs_test)

            # Validate Normals
            pred_normals = ret['normals_volume']
            gt_normals = val_dataset.normals[val_ind].reshape(1, -1, 3)

            # BGR to RGB (OpenCV legacy)
            gt_normals = gt_normals[..., [2, 1, 0]]
            # Flip XZ (Mitsuba)
            gt_normals[..., [0, 2]] *= -1

            num_pixel = gt_normals.shape[1]
            validate_mask = val_gt['mask']
            pred_normals_, gt_normals_ = pred_normals[validate_mask, :].cpu(), gt_normals[validate_mask, :]
            cos_loss = cos(pred_normals_.view(-1, 3), gt_normals_.view(-1, 3))
            cos_loss = torch.clamp(cos_loss, (-1.0 + 1e-10), (1.0 - 1e-10))
            loss_rad = torch.acos(cos_loss)
            loss_deg = (loss_rad * (180.0 / np.pi)).sum() / num_pixel
            total_MAE += loss_deg
            pred_normals_img = (ret['normals_volume'] / 2. + 0.5)
            pred_normals_img[~validate_mask, :] = 0.
            logger.add_imgs(to_img(pred_normals_img), 'eval/predicted_normals', val_ind)
            logger.add_imgs(to_img(gt_normals / 2. + 0.5), 'eval/gt_normals', val_ind)
            f.write(str(val_ind) + '_MAE:')
            f.write(str(loss_deg.data.item()) + ' ')
            f.write('\n')
            f.flush()

    MAE = total_MAE / len(valloader)
    f.write('\n')
    f.write('MAE_final:')
    f.write(str(MAE.data.item()) + ' ')
    f.close()

    # -------------------
    # validate mesh
    # -------------------
    # if is_master():
    #     with torch.no_grad():
    #         io_util.cond_mkdir(mesh_dir)
    #         mesh_util.extract_mesh(
    #             model.implicit_surface,
    #             filepath=os.path.join(mesh_dir, '{:08d}.ply'.format(it)),
    #             volume_size=args.data.get('volume_size', 2.0),
    #             show_progress=is_master())

    log.info("Everything done.")


if __name__ == "__main__":
    # Arguments
    parser = io_util.create_args_parser()
    parser.add_argument("--ddp", action='store_true', help='whether to use DDP to train.')
    parser.add_argument("--port", type=int, default=None, help='master port for multi processing. (if used)')
    args, unknown = parser.parse_known_args()
    config = io_util.load_config(args, unknown, base_config_path=None if args.base_config is None else args.base_config)
    main_function(config)
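For reference, the per-view error written above is the mean angular error, MAE = mean(arccos(clamp(<n_pred, n_gt>, -1, 1)) * 180/pi). A self-contained toy check of the same computation on random normals:

import numpy as np
import torch

cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
pred = torch.nn.functional.normalize(torch.randn(1000, 3), dim=-1)
gt = torch.nn.functional.normalize(pred + 0.05 * torch.randn(1000, 3), dim=-1)
c = torch.clamp(cos(pred, gt), -1.0 + 1e-10, 1.0 - 1e-10)
mae_deg = (torch.acos(c) * (180.0 / np.pi)).mean()
print(mae_deg.item())   # a few degrees for this noise level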
--------------------------------------------------------------------------------
/vis_weights.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch
from models.frameworks import get_model
from utils import rend_util, train_util, mesh_util, io_util
import os, json, random
import numpy as np

data_dir = 'data/pol/ceramicCat'
idx = 1
H = 2048
W = 2448

config_path_1 = 'logs/VolSDF/baseline/ceramicCat/config.yaml'
config_path_2 = 'logs/PNeuS/ceramicCat/wrgb_0.1/config.yaml'

ckpt_path_1 = os.path.join(*config_path_1.split('/')[:-1], 'ckpts/latest.pt')
print(ckpt_path_1)

ckpt_path_2 = os.path.join(*config_path_2.split('/')[:-1], 'ckpts/latest.pt')
print(ckpt_path_2)

config_1 = io_util.load_yaml(config_path_1)
config_1.device_ids = [0]
config_2 = io_util.load_yaml(config_path_2)
config_2.device_ids = [0]
model_1, trainer_1, render_kwargs_train_1, render_kwargs_test_1, volume_render_fn_1 = get_model(config_1)
model_1.cuda()
model_1.eval()
model_2, trainer_2, render_kwargs_train_2, render_kwargs_test_2, volume_render_fn_2 = get_model(config_2)
model_2.cuda()
model_2.eval()

state_dict = torch.load(ckpt_path_1)
model_1.load_state_dict(state_dict['model'])
state_dict = torch.load(ckpt_path_2)
model_2.load_state_dict(state_dict['model'])

with open(os.path.join(data_dir, 'camera_intrinsics.json'), 'r') as f:
    camera_intrinsics = json.load(f)['intrinsics']

with open(os.path.join(data_dir, 'camera_extrinsics.json'), 'r') as f:
    camera_extrinsics = json.load(f)

name_idx = list(camera_extrinsics.keys())[idx]
print(name_idx)
rotation = torch.Tensor(camera_extrinsics[name_idx]['rotation']).transpose(1, 0).float()  # R^T
translation = torch.Tensor(camera_extrinsics[name_idx]['camera_pos']).float()  # C = -R^T @ t
c2w_ = torch.cat([rotation, translation.unsqueeze(1)], dim=1)  # 3 x 4
c2w = torch.cat([c2w_, torch.Tensor([[0., 0., 0., 1.]])], dim=0)
cam_center_norms = np.linalg.norm(translation.numpy())
c2w_1 = c2w * (3.0 / cam_center_norms)
c2w_2 = c2w * (2.0 / cam_center_norms)
intrinsics = torch.Tensor(camera_intrinsics)
rays_o_1, rays_d_1, rays_o_2, rays_d_2, select_inds = rend_util.get_birays(
    c2w_1, c2w_2, intrinsics, H, W, N_rays=512)
rays_o_1 = rays_o_1.cuda()[None, ...]
rays_d_1 = rays_d_1.cuda()[None, ...]
rays_o_2 = rays_o_2.cuda()[None, ...]
rays_d_2 = rays_d_2.cuda()[None, ...]

rgb, depth_v, extras_1 = volume_render_fn_1(rays_o_1, rays_d_1, detailed_output=True, **render_kwargs_train_1)
rgb, depth_v, extras_2 = volume_render_fn_2(rays_o_2, rays_d_2, detailed_output=True, **render_kwargs_train_2)

alphas_1 = extras_1['alpha']
sdfs_1 = extras_1['implicit_surface']
weights_1 = extras_1['visibility_weights']
try:
    cdfs_1 = extras_1['cdf']
except KeyError:
    cdfs_1 = extras_1['sigma']
try:
    depths_1 = extras_1['d_final']
except KeyError:
    depths_1 = extras_1['d_vals']
try:
    beta_1 = extras_1['beta_warp']
except KeyError:
    beta_1 = None

alphas_2 = extras_2['alpha']
sdfs_2 = extras_2['implicit_surface']
weights_2 = extras_2['visibility_weights']
try:
    cdfs_2 = extras_2['cdf']
except KeyError:
    cdfs_2 = extras_2['sigma']
try:
    depths_2 = extras_2['d_final']
except KeyError:
    depths_2 = extras_2['d_vals']
try:
    beta_2 = extras_2['beta_warp']
except KeyError:
    beta_2 = None

N = sdfs_1.shape[-1] - 1

ind_1 = random.randint(0, N)
# depths = torch.arange(N + 32)

# NOTE: the original appended an extra element to alpha_1/weight_1 here
# (np.concatenate with ones/zeros), which made them one entry longer than
# depths_1 and broke plt.scatter; the padding is dropped so lengths match.
alpha_1 = alphas_1[0][ind_1][:127].detach().cpu().numpy()
weight_1 = weights_1[0][ind_1][:127].detach().cpu().numpy()
sdf_1 = sdfs_1[0][ind_1][:127].detach().cpu().numpy()
# cdf_1 = torch.cat((cdfs[0][ind_1], cdfs[0][ind_1][-1] * torch.ones((31)).cuda()), dim=-1)
cdf_1 = cdfs_1[0][ind_1][:127].detach().cpu().numpy()
depths_1 = depths_1[0][ind_1][:127].detach().cpu().numpy()

alpha_2 = alphas_2[0][ind_1][:127].detach().cpu().numpy()
weight_2 = weights_2[0][ind_1][:127].detach().cpu().numpy()
sdf_2 = sdfs_2[0][ind_1][:127].detach().cpu().numpy()
cdf_2 = cdfs_2[0][ind_1][:127].detach().cpu().numpy()
depths_2 = depths_2[0][ind_1][:127].detach().cpu().numpy()

if beta_2 is not None:
    beta_2 = beta_2[0][ind_1][:127].detach().cpu().numpy()

fig, ax = plt.subplots(1, 2, figsize=(12, 6), dpi=120)
# ax[0, 0].scatter(x=depths, y=sdf_1)
# ax[0, 0].set_xlabel('Depth', fontsize=16)
# ax[0, 0].set_ylabel('sdf', fontsize=16)
# ax[0, 0].set_title('sdf', fontsize=16)

# if beta is not None:
#     ax[0, 1].scatter(x=depths, y=beta_1)
#     ax[0, 1].set_xlabel('Depth', fontsize=16)
#     ax[0, 1].set_ylabel('beta_warp', fontsize=16)
#     ax[0, 1].set_title('beta_warp', fontsize=16)
# else:
#     ax[0, 1].scatter(x=depths, y=cdf_1)
#     ax[0, 1].set_xlabel('Depth', fontsize=16)
#     ax[0, 1].set_ylabel('cdf', fontsize=16)
#     ax[0, 1].set_title('cdf', fontsize=16)

# ax[1, 0].scatter(x=depths, y=alpha_1)
# ax[1, 0].set_xlabel('Depth', fontsize=16)
# ax[1, 0].set_ylabel('alpha', fontsize=16)
# ax[1, 0].set_title('alpha', fontsize=16)

ax[0].scatter(x=depths_1, y=weight_1)
ax[0].set_xlabel('Depth', fontsize=16)
ax[0].set_xticks([])
ax[0].set_ylabel('weight', fontsize=16)
ax[0].set_ylim(0, 0.5)
ax[0].set_title('VolSDF', fontsize=16)
ax[1].scatter(x=depths_2, y=weight_2)
ax[1].set_xlabel('Depth', fontsize=16)
ax[1].set_xticks([])
ax[1].set_ylabel('weight', fontsize=16)
ax[1].set_ylim(0, 0.5)
ax[1].set_title('PNeuS', fontsize=16)

fig.savefig('./weights_vis.png')
--------------------------------------------------------------------------------
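A closing sketch of the extrinsics convention assumed in vis_weights.py: the JSON stores the world-to-camera rotation R and the camera center C, so the camera-to-world matrix is assembled as [[R^T, C], [0, 1]] (toy values below):

import torch

R = torch.eye(3)                   # world-to-camera rotation from the JSON
C = torch.tensor([0.0, 0.0, 3.0])  # camera center in world coordinates
c2w = torch.eye(4)
c2w[:3, :3] = R.transpose(1, 0)    # R^T
c2w[:3, 3] = C
print(c2w)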